diff --git a/build.gradle b/build.gradle
index ab170a9f4..546679e67 100644
--- a/build.gradle
+++ b/build.gradle
@@ -126,9 +126,12 @@ dependencies {
     implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
     implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
     implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.10.0'
-    implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.0-rc3'
-    implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.0-rc3'
-    implementation 'software.amazon.randomcutforest:randomcutforest-core:3.0-rc3'
+    implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.7.0'
+    implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.7.0'
+    implementation 'software.amazon.randomcutforest:randomcutforest-core:3.7.0'
+    //implementation files('lib/randomcutforest-core-3.5.0.jar')
+    //implementation files('lib/randomcutforest-serialization-3.5.0.jar')
+    //implementation files('lib/randomcutforest-parkservices-3.5.0.jar')
 
     // we inherit jackson-core from opensearch core
     implementation "com.fasterxml.jackson.core:jackson-databind:2.14.1"
@@ -402,7 +405,8 @@ testClusters.integTest {
             return new RegularFile() {
                 @Override
                 File getAsFile() {
-                    return configurations.zipArchive.asFileTree.getSingleFile()
+                    //return configurations.zipArchive.asFileTree.getSingleFile()
+                    return fileTree("src/test/resources/job-scheduler").getSingleFile()
                 }
             }
         }
diff --git a/src/main/java/org/opensearch/ad/ADJobProcessor.java b/src/main/java/org/opensearch/ad/ADJobProcessor.java
new file mode 100644
index 000000000..7621f9fc9
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/ADJobProcessor.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad;
+
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.settings.AnomalyDetectorSettings;
+import org.opensearch.ad.task.ADTaskCacheManager;
+import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.transport.AnomalyResultAction;
+import org.opensearch.ad.transport.AnomalyResultRequest;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.JobProcessor;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.transport.ResultRequest;
+
+public class ADJobProcessor extends
+    JobProcessor {
+
+    private static ADJobProcessor INSTANCE;
+
+    public static ADJobProcessor getInstance() {
+        if (INSTANCE != null) {
+            return INSTANCE;
+        }
+        synchronized (JobProcessor.class) {
+            if (INSTANCE != null) {
+                return INSTANCE;
+            }
+            INSTANCE = new ADJobProcessor();
+            return INSTANCE;
+        }
+    }
+
+    private ADJobProcessor() {
+        // Singleton class, use getInstance() instead of the constructor
+        super(AnalysisType.AD, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, AnomalyResultAction.INSTANCE);
+    }
+
+    public void registerSettings(Settings settings) {
+        super.registerSettings(settings, AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION);
+    }
+
+    @Override
+    protected ResultRequest createResultRequest(String configId, long start, long end) {
+        return new AnomalyResultRequest(configId, start, end);
+    }
+}
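Note on the singleton above: getInstance() uses the double-checked locking idiom. A minimal standalone sketch of the conventional form (hypothetical, not part of this patch) marks the field volatile so that a fully constructed instance is safely published to threads that read the field outside the synchronized block:

    public final class SafeSingleton {
        // volatile guarantees readers outside the synchronized block see a
        // fully constructed instance, never a partially initialized one
        private static volatile SafeSingleton instance;

        private SafeSingleton() {
        }

        public static SafeSingleton getInstance() {
            SafeSingleton local = instance; // single volatile read on the fast path
            if (local == null) {
                synchronized (SafeSingleton.class) {
                    local = instance;
                    if (local == null) {
                        local = new SafeSingleton();
                        instance = local;
                    }
                }
            }
            return local;
        }
    }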
diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java
index d9d5e3f7b..af5816497 100644
--- a/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java
+++ b/src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java
@@ -15,7 +15,6 @@
 import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
 import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
 
 import java.util.List;
 import java.util.Map;
@@ -35,10 +34,8 @@
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.DetectorProfile;
 import org.opensearch.ad.model.DetectorProfileName;
-import org.opensearch.ad.model.DetectorState;
 import org.opensearch.ad.model.InitProgressProfile;
 import org.opensearch.ad.settings.ADNumericSetting;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
@@ -49,9 +46,6 @@
 import org.opensearch.ad.transport.RCFPollingAction;
 import org.opensearch.ad.transport.RCFPollingRequest;
 import org.opensearch.ad.transport.RCFPollingResponse;
-import org.opensearch.ad.util.ExceptionUtil;
-import org.opensearch.ad.util.MultiResponsesDelegateActionListener;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
@@ -68,11 +62,19 @@
 import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
 import org.opensearch.search.aggregations.metrics.InternalCardinality;
 import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.ProfileUtil;
 import org.opensearch.timeseries.common.exception.NotSerializedExceptionName;
 import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
+import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.model.ConfigState;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
+import org.opensearch.timeseries.util.ExceptionUtil;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 import org.opensearch.transport.TransportService;
 
 public class AnomalyDetectorProfileRunner extends AbstractProfileRunner {
@@ -136,11 +138,11 @@ private void calculateTotalResponsesToWait(
                     listener.onFailure(new OpenSearchStatusException(FAIL_TO_PARSE_DETECTOR_MSG + detectorId, BAD_REQUEST));
                 }
             } else {
-                listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, BAD_REQUEST));
+                listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId, BAD_REQUEST));
             }
         }, exception -> {
-            logger.error(FAIL_TO_FIND_CONFIG_MSG + detectorId, exception);
-            listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, INTERNAL_SERVER_ERROR));
+            logger.error(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId, exception);
+            listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + detectorId, INTERNAL_SERVER_ERROR));
         }));
     }
@@ -159,7 +161,7 @@ private void prepareProfile(
                 .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString())
         ) {
             ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-            AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
+            Job job = Job.parse(parser);
             long enabledTimeMs = job.getEnabledTime().toEpochMilli();
 
             boolean isMultiEntityDetector = detector.isHighCardinality();
@@ -211,7 +213,7 @@ private void prepareProfile(
             false
         );
         if (profilesToCollect.contains(DetectorProfileName.ERROR)) {
-            adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, adTask -> {
+            adTaskManager.getAndExecuteOnLatestConfigLevelTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, adTask -> {
                 DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder();
                 if (adTask.isPresent()) {
                     long lastUpdateTimeMs = adTask.get().getLastUpdateTime().toEpochMilli();
@@ -315,6 +317,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener listener, Set profiles) {
         DetectorProfile.Builder profileBuilder = new DetectorProfile.Builder();
         if (profiles.contains(DetectorProfileName.STATE)) {
-            profileBuilder.state(DetectorState.DISABLED);
+            profileBuilder.state(ConfigState.DISABLED);
         }
         if (profiles.contains(DetectorProfileName.AD_TASK)) {
             adTaskManager.getLatestHistoricalTaskProfile(detectorId, transportService, profileBuilder.build(), listener);
@@ -409,7 +413,7 @@ private void profileStateRelated(
         } else {
             DetectorProfile.Builder builder = new DetectorProfile.Builder();
             if (profilesToCollect.contains(DetectorProfileName.STATE)) {
-                builder.state(DetectorState.DISABLED);
+                builder.state(ConfigState.DISABLED);
             }
             listener.onResponse(builder.build());
         }
@@ -418,7 +422,7 @@ private void profileModels(
         AnomalyDetector detector,
         Set profiles,
-        AnomalyDetectorJob job,
+        Job job,
         boolean forMultiEntityDetector,
         MultiResponsesDelegateActionListener listener
     ) {
@@ -430,7 +434,7 @@ private ActionListener onModelResponse(
         AnomalyDetector detector,
         Set profilesToCollect,
-        AnomalyDetectorJob job,
+        Job job,
         MultiResponsesDelegateActionListener listener
     ) {
         boolean isMultientityDetector = detector.isHighCardinality();
@@ -464,7 +468,7 @@ private ActionListener onModelResponse(
     }
 
     private void profileMultiEntityDetectorStateRelated(
-        AnomalyDetectorJob job,
+        Job job,
         Set profilesToCollect,
         ProfileResponse profileResponse,
         DetectorProfile.Builder profileBuilder,
@@ -478,10 +482,11 @@ private void profileMultiEntityDetectorStateRelated(
             long enabledTime = job.getEnabledTime().toEpochMilli();
             long totalUpdates = profileResponse.getTotalUpdates();
             ProfileUtil
-                .confirmDetectorRealtimeInitStatus(
+                .confirmRealtimeInitStatus(
                     detector,
                     enabledTime,
                     client,
+                    AnalysisType.AD,
                     onInittedEver(enabledTime, profileBuilder, profilesToCollect, detector, totalUpdates, listener)
                 );
         } else {
@@ -490,7 +495,7 @@ private void profileMultiEntityDetectorStateRelated(
         }
     } else {
         if (profilesToCollect.contains(DetectorProfileName.STATE)) {
-            profileBuilder.state(DetectorState.DISABLED);
+            profileBuilder.state(ConfigState.DISABLED);
         }
         listener.onResponse(profileBuilder.build());
     }
@@ -577,7 +582,7 @@ private ActionListener onPollRCFUpdates(
     private void createRunningStateAndInitProgress(Set profilesToCollect, DetectorProfile.Builder builder) {
         if (profilesToCollect.contains(DetectorProfileName.STATE)) {
-            builder.state(DetectorState.RUNNING).build();
+            builder.state(ConfigState.RUNNING).build();
         }
 
         if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) {
@@ -595,7 +600,7 @@ private void processInitResponse(
         MultiResponsesDelegateActionListener listener
     ) {
         if (profilesToCollect.contains(DetectorProfileName.STATE)) {
-            builder.state(DetectorState.INIT);
+            builder.state(ConfigState.INIT);
         }
 
         if (profilesToCollect.contains(DetectorProfileName.INIT_PROGRESS)) {
diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java
index 90b7d350f..9c6ce5cde 100644
--- a/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java
+++ b/src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java
@@ -25,19 +25,19 @@
 import org.apache.logging.log4j.Logger;
 import org.opensearch.OpenSearchSecurityException;
 import org.opensearch.action.ActionListener;
-import org.opensearch.ad.constant.CommonValue;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.feature.Features;
-import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.model.EntityAnomalyResult;
-import org.opensearch.ad.util.MultiResponsesDelegateActionListener;
 import org.opensearch.common.util.concurrent.ThreadContext;
+import org.opensearch.timeseries.constant.CommonValue;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.feature.Features;
 import org.opensearch.timeseries.model.Entity;
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.FeatureData;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
 
 /**
  * Runner to trigger an anomaly detector.
@@ -45,11 +45,11 @@ public final class AnomalyDetectorRunner {
 
     private final Logger logger = LogManager.getLogger(AnomalyDetectorRunner.class);
-    private final ModelManager modelManager;
+    private final ADModelManager modelManager;
     private final FeatureManager featureManager;
     private final int maxPreviewResults;
 
-    public AnomalyDetectorRunner(ModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) {
+    public AnomalyDetectorRunner(ADModelManager modelManager, FeatureManager featureManager, int maxPreviewResults) {
         this.modelManager = modelManager;
         this.featureManager = featureManager;
         this.maxPreviewResults = maxPreviewResults;
@@ -168,24 +168,24 @@ private List parsePreviewResult(
             AnomalyResult result;
             if (results != null && results.size() > i) {
-                ThresholdingResult thresholdingResult = results.get(i);
-                List resultsToSave = thresholdingResult
-                    .toIndexableResults(
-                        detector,
-                        Instant.ofEpochMilli(timeRange.getKey()),
-                        Instant.ofEpochMilli(timeRange.getValue()),
-                        null,
-                        null,
-                        featureDatas,
-                        Optional.ofNullable(entity),
-                        CommonValue.NO_SCHEMA_VERSION,
-                        null,
-                        null,
-                        null
+                anomalyResults
+                    .addAll(
+                        results
+                            .get(i)
+                            .toIndexableResults(
+                                detector,
+                                Instant.ofEpochMilli(timeRange.getKey()),
+                                Instant.ofEpochMilli(timeRange.getValue()),
+                                null,
+                                null,
+                                featureDatas,
+                                Optional.ofNullable(entity),
+                                CommonValue.NO_SCHEMA_VERSION,
+                                null,
+                                null,
+                                null
+                            )
                     );
-                for (AnomalyResult r : resultsToSave) {
-                    anomalyResults.add(r);
-                }
             } else {
                 result = new AnomalyResult(
                     detector.getId(),
diff --git a/src/main/java/org/opensearch/ad/EntityProfileRunner.java b/src/main/java/org/opensearch/ad/EntityProfileRunner.java
index 491e8088f..479260e21 100644
--- a/src/main/java/org/opensearch/ad/EntityProfileRunner.java
+++ b/src/main/java/org/opensearch/ad/EntityProfileRunner.java
@@ -28,7 +28,6 @@
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.model.EntityProfile;
 import org.opensearch.ad.model.EntityProfileName;
@@ -38,8 +37,6 @@
 import org.opensearch.ad.transport.EntityProfileAction;
 import org.opensearch.ad.transport.EntityProfileRequest;
 import org.opensearch.ad.transport.EntityProfileResponse;
-import org.opensearch.ad.util.MultiResponsesDelegateActionListener;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.routing.Preference;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
@@ -53,11 +50,15 @@
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.search.aggregations.AggregationBuilders;
 import org.opensearch.search.builder.SearchSourceBuilder;
+import org.opensearch.timeseries.AnalysisType;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.Entity;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
 import org.opensearch.timeseries.util.ParseUtils;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 
 public class EntityProfileRunner extends AbstractProfileRunner {
     private final Logger logger = LogManager.getLogger(EntityProfileRunner.class);
@@ -188,6 +189,7 @@ private void validateEntity(
                 client::search,
                 detector.getId(),
                 client,
+                AnalysisType.AD,
                 searchResponseListener
             );
 
@@ -228,7 +230,7 @@ private void getJob(
                     .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, getResponse.getSourceAsString())
             ) {
                 ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
+                Job job = Job.parse(parser);
 
                 int totalResponsesToWait = 0;
                 if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)
@@ -331,7 +333,7 @@ private void profileStateRelated(
         Entity entityValue,
         Set profilesToCollect,
         AnomalyDetector detector,
-        AnomalyDetectorJob job,
+        Job job,
         MultiResponsesDelegateActionListener delegateListener
     ) {
         if (totalUpdates == 0) {
diff --git a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java
index 4b05295ae..f16faac36 100644
--- a/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java
+++ b/src/main/java/org/opensearch/ad/ExecuteADResultResponseRecorder.java
@@ -11,371 +11,120 @@
 package org.opensearch.ad;
 
-import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK;
-
 import java.time.Instant;
 import java.util.ArrayList;
-import java.util.HashSet;
 import java.util.Optional;
-import java.util.Set;
-import java.util.concurrent.TimeUnit;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.opensearch.action.ActionListener;
-import org.opensearch.action.update.UpdateResponse;
-import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyResult;
-import org.opensearch.ad.model.DetectorProfileName;
 import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.transport.AnomalyResultResponse;
-import org.opensearch.ad.transport.ProfileAction;
-import org.opensearch.ad.transport.ProfileRequest;
-import org.opensearch.ad.transport.RCFPollingAction;
-import org.opensearch.ad.transport.RCFPollingRequest;
-import org.opensearch.ad.transport.handler.AnomalyIndexHandler;
-import org.opensearch.ad.util.ExceptionUtil;
 import org.opensearch.client.Client;
-import org.opensearch.cluster.node.DiscoveryNode;
-import org.opensearch.common.unit.TimeValue;
 import org.opensearch.commons.authuser.User;
-import org.opensearch.search.SearchHits;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
-import org.opensearch.timeseries.common.exception.EndRunException;
-import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
-import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.ExecuteResultResponseRecorder;
+import org.opensearch.timeseries.NodeStateManager;
 import org.opensearch.timeseries.model.FeatureData;
-import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.task.TaskCacheManager;
+import org.opensearch.timeseries.transport.ResultResponse;
+import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
 
-public class ExecuteADResultResponseRecorder {
-    private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class);
+public class ExecuteADResultResponseRecorder extends
+    ExecuteResultResponseRecorder {
 
-    private ADIndexManagement anomalyDetectionIndices;
-    private AnomalyIndexHandler anomalyResultHandler;
-    private ADTaskManager adTaskManager;
-    private DiscoveryNodeFilterer nodeFilter;
-    private ThreadPool threadPool;
-    private Client client;
-    private NodeStateManager nodeStateManager;
-    private ADTaskCacheManager adTaskCacheManager;
-    private int rcfMinSamples;
+    private static final Logger log = LogManager.getLogger(ExecuteADResultResponseRecorder.class);
 
     public ExecuteADResultResponseRecorder(
-        ADIndexManagement anomalyDetectionIndices,
-        AnomalyIndexHandler anomalyResultHandler,
-        ADTaskManager adTaskManager,
+        ADIndexManagement indexManagement,
+        ResultBulkIndexingHandler resultHandler,
+        ADTaskManager taskManager,
         DiscoveryNodeFilterer nodeFilter,
         ThreadPool threadPool,
         Client client,
         NodeStateManager nodeStateManager,
-        ADTaskCacheManager adTaskCacheManager,
+        TaskCacheManager taskCacheManager,
         int rcfMinSamples
     ) {
-        this.anomalyDetectionIndices = anomalyDetectionIndices;
-        this.anomalyResultHandler = anomalyResultHandler;
-        this.adTaskManager = adTaskManager;
-        this.nodeFilter = nodeFilter;
-        this.threadPool = threadPool;
-        this.client = client;
-        this.nodeStateManager = nodeStateManager;
-        this.adTaskCacheManager = adTaskCacheManager;
-        this.rcfMinSamples = rcfMinSamples;
+        super(
+            indexManagement,
+            resultHandler,
+            taskManager,
+            nodeFilter,
+            threadPool,
+            TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME,
+            client,
+            nodeStateManager,
+            taskCacheManager,
+            rcfMinSamples,
+            ADIndex.RESULT,
+            AnalysisType.AD
+        );
     }
 
-    public void indexAnomalyResult(
-        Instant detectionStartTime,
-        Instant executionStartTime,
-        AnomalyResultResponse response,
-        AnomalyDetector detector
+    @Override
+    protected AnomalyResult createErrorResult(
+        String configId,
+        Instant dataStartTime,
+        Instant dataEndTime,
+        Instant executeEndTime,
+        String errorMessage,
+        User user
     ) {
-        String detectorId = detector.getId();
-        try {
-            // skipping writing to the result index if not necessary
-            // For a single-entity detector, the result is not useful if error is null
-            // and rcf score (thus anomaly grade/confidence) is null.
-            // For a HCAD detector, we don't need to save on the detector level.
-            // We return 0 or Double.NaN rcf score if there is no error.
-            if ((response.getAnomalyScore() <= 0 || Double.isNaN(response.getAnomalyScore())) && response.getError() == null) {
-                updateRealtimeTask(response, detectorId);
-                return;
-            }
-            IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
-            Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
-            Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
-            User user = detector.getUser();
-
-            if (response.getError() != null) {
-                log.info("Anomaly result action run successfully for {} with error {}", detectorId, response.getError());
-            }
-
-            AnomalyResult anomalyResult = response
-                .toAnomalyResult(
-                    detectorId,
-                    dataStartTime,
-                    dataEndTime,
-                    executionStartTime,
-                    Instant.now(),
-                    anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
-                    user,
-                    response.getError()
-                );
-
-            String resultIndex = detector.getCustomResultIndex();
-            anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
-            updateRealtimeTask(response, detectorId);
-        } catch (EndRunException e) {
-            throw e;
-        } catch (Exception e) {
-            log.error("Failed to index anomaly result for " + detectorId, e);
-        }
+        return new AnomalyResult(
+            configId,
+            null, // no task id
+            new ArrayList(),
+            dataStartTime,
+            dataEndTime,
+            executeEndTime,
+            Instant.now(),
+            errorMessage,
+            Optional.empty(), // single-stream detectors have no entity
+            user,
+            indexManagement.getSchemaVersion(resultIndex),
+            null // no model id
+        );
     }
 
     /**
      * Update real time task (one document per detector in state index). If the real-time task has no changes compared with local cache,
-     * the task won't update. Task only updates when the state changed, or any error happened, or AD job stopped. Task is mainly consumed
-     * by the front-end to track detector status. For single-stream detectors, we embed model total updates in AnomalyResultResponse and
-     * update state accordingly. For HCAD, we won't wait for model finishing updating before returning a response to the job scheduler
+     * the task won't update. Task only updates when the state changed, or any error happened, or job stopped. Task is mainly consumed
+     * by the front-end to track analysis status. For single-stream analyses, we embed model total updates in ResultResponse and
+     * update state accordingly. For HC analysis, we won't wait for model finishing updating before returning a response to the job scheduler
      * since it might be long before all entities finish execution. So we don't embed model total updates in AnomalyResultResponse.
     * Instead, we issue a profile request to poll each model node and get the maximum total updates among all models.
     * @param response response returned from executing AnomalyResultAction
-     * @param detectorId Detector Id
+     * @param configId config Id
     */
-    private void updateRealtimeTask(AnomalyResultResponse response, String detectorId) {
-        if (response.isHCDetector() != null && response.isHCDetector()) {
-            if (adTaskManager.skipUpdateHCRealtimeTask(detectorId, response.getError())) {
+    @Override
+    protected void updateRealtimeTask(ResultResponse response, String configId) {
+        if (response.isHC() != null && response.isHC()) {
+            if (taskManager.skipUpdateRealtimeTask(configId, response.getError())) {
                 return;
             }
-            DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
-            Set profiles = new HashSet<>();
-            profiles.add(DetectorProfileName.INIT_PROGRESS);
-            ProfileRequest profileRequest = new ProfileRequest(detectorId, profiles, true, dataNodes);
-            Runnable profileHCInitProgress = () -> {
-                client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> {
-                    log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates());
-                    updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError());
-                }, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); }));
-            };
-            if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) {
-                // real time init progress is 0 may mean this is a newly started detector
-                // Delay real time cache update by one minute. If we are in init status, the delay may give the model training time to
-                // finish. We can change the detector running immediately instead of waiting for the next interval.
-                threadPool
-                    .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME);
-            } else {
-                profileHCInitProgress.run();
-            }
-
+            delayedUpdate(response, configId);
         } else {
             log
                 .debug(
                     "Update latest realtime task for single stream detector {}, total updates: {}",
-                    detectorId,
+                    configId,
                     response.getRcfTotalUpdates()
                 );
-            updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError());
-        }
-    }
-
-    private void updateLatestRealtimeTask(
-        String detectorId,
-        String taskState,
-        Long rcfTotalUpdates,
-        Long detectorIntervalInMinutes,
-        String error
-    ) {
-        // Don't need info as this will be printed repeatedly in each interval
-        ActionListener listener = ActionListener.wrap(r -> {
-            if (r != null) {
-                log.debug("Updated latest realtime task successfully for detector {}, taskState: {}", detectorId, taskState);
-            }
-        }, e -> {
-            if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CAN_NOT_FIND_LATEST_TASK)) {
-                // Clear realtime task cache, will recreate AD task in next run, check AnomalyResultTransportAction.
-                log.error("Can't find latest realtime task of detector " + detectorId);
-                adTaskManager.removeRealtimeTaskCache(detectorId);
-            } else {
-                log.error("Failed to update latest realtime task for detector " + detectorId, e);
-            }
-        });
-
-        // rcfTotalUpdates is null when we save exception messages
-        if (!adTaskCacheManager.hasQueriedResultIndex(detectorId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) {
-            // confirm the total updates number since it is possible that we have already had results after job enabling time
-            // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%.
-            confirmTotalRCFUpdatesFound(
-                detectorId,
-                taskState,
-                rcfTotalUpdates,
-                detectorIntervalInMinutes,
-                error,
-                ActionListener
-                    .wrap(
-                        r -> adTaskManager
-                            .updateLatestRealtimeTaskOnCoordinatingNode(
-                                detectorId,
-                                taskState,
-                                r,
-                                detectorIntervalInMinutes,
-                                error,
-                                listener
-                            ),
-                        e -> {
-                            log.error("Fail to confirm rcf update", e);
-                            adTaskManager
-                                .updateLatestRealtimeTaskOnCoordinatingNode(
-                                    detectorId,
-                                    taskState,
-                                    rcfTotalUpdates,
-                                    detectorIntervalInMinutes,
-                                    error,
-                                    listener
-                                );
-                        }
-                    )
-            );
-        } else {
-            adTaskManager
-                .updateLatestRealtimeTaskOnCoordinatingNode(
-                    detectorId,
-                    taskState,
-                    rcfTotalUpdates,
-                    detectorIntervalInMinutes,
-                    error,
-                    listener
-                );
-        }
-    }
-
-    /**
-     * The function is not only indexing the result with the exception, but also updating the task state after
-     * 60s if the exception is related to cold start (index not found exceptions) for a single stream detector.
-     *
-     * @param detectionStartTime execution start time
-     * @param executionStartTime execution end time
-     * @param errorMessage Error message to record
-     * @param taskState AD task state (e.g., stopped)
-     * @param detector Detector config accessor
-     */
-    public void indexAnomalyResultException(
-        Instant detectionStartTime,
-        Instant executionStartTime,
-        String errorMessage,
-        String taskState,
-        AnomalyDetector detector
-    ) {
-        String detectorId = detector.getId();
-        try {
-            IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
-            Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
-            Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
-            User user = detector.getUser();
-
-            AnomalyResult anomalyResult = new AnomalyResult(
-                detectorId,
-                null, // no task id
-                new ArrayList(),
-                dataStartTime,
-                dataEndTime,
-                executionStartTime,
-                Instant.now(),
-                errorMessage,
-                Optional.empty(), // single-stream detectors have no entity
-                user,
-                anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
-                null // no model id
-            );
-            String resultIndex = detector.getCustomResultIndex();
-            if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) {
-                // Set result index as null, will write exception to default result index.
-                anomalyResultHandler.index(anomalyResult, detectorId, null);
-            } else {
-                anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
-            }
-
-            if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHighCardinality()) {
-                // single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG
-                // when there is no checkpoint.
-                // Delay real time cache update by one minute so we will have trained models by then and update the state
-                // document accordingly.
-                threadPool.schedule(() -> {
-                    RCFPollingRequest request = new RCFPollingRequest(detectorId);
-                    client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> {
-                        long totalUpdates = rcfPollResponse.getTotalUpdates();
-                        // if there are updates, don't record failures
-                        updateLatestRealtimeTask(
-                            detectorId,
-                            taskState,
-                            totalUpdates,
-                            detector.getIntervalInMinutes(),
-                            totalUpdates > 0 ? "" : errorMessage
-                        );
-                    }, e -> {
-                        log.error("Fail to execute RCFRollingAction", e);
-                        updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage);
-                    }));
-                }, new TimeValue(60, TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME);
-            } else {
-                updateLatestRealtimeTask(detectorId, taskState, null, null, errorMessage);
-            }
-
-        } catch (Exception e) {
-            log.error("Failed to index anomaly result for " + detectorId, e);
+            updateLatestRealtimeTask(
+                configId,
+                null,
+                response.getRcfTotalUpdates(),
+                response.getConfigIntervalInMinutes(),
+                response.getError()
+            );
         }
     }
-
-    private void confirmTotalRCFUpdatesFound(
-        String detectorId,
-        String taskState,
-        Long rcfTotalUpdates,
-        Long detectorIntervalInMinutes,
-        String error,
-        ActionListener listener
-    ) {
-        nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> {
-            if (!detectorOptional.isPresent()) {
-                listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector"));
-                return;
-            }
-            nodeStateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(jobOptional -> {
-                if (!jobOptional.isPresent()) {
-                    listener.onFailure(new TimeSeriesException(detectorId, "fail to get job"));
-                    return;
-                }
-
-                ProfileUtil
-                    .confirmDetectorRealtimeInitStatus(
-                        detectorOptional.get(),
-                        jobOptional.get().getEnabledTime().toEpochMilli(),
-                        client,
-                        ActionListener.wrap(searchResponse -> {
-                            ActionListener.completeWith(listener, () -> {
-                                SearchHits hits = searchResponse.getHits();
-                                Long correctedTotalUpdates = rcfTotalUpdates;
-                                if (hits.getTotalHits().value > 0L) {
-                                    // correct the number if we have already had results after job enabling time
-                                    // so that the detector won't stay initialized
-                                    correctedTotalUpdates = Long.valueOf(rcfMinSamples);
-                                }
-                                adTaskCacheManager.markResultIndexQueried(detectorId);
-                                return correctedTotalUpdates;
-                            });
-                        }, exception -> {
-                            if (ExceptionUtil.isIndexNotAvailable(exception)) {
-                                // anomaly result index is not created yet
-                                adTaskCacheManager.markResultIndexQueried(detectorId);
-                                listener.onResponse(0L);
-                            } else {
-                                listener.onFailure(exception);
-                            }
-                        })
-                    );
-            }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get job"))));
-        }, e -> listener.onFailure(new TimeSeriesException(detectorId, "fail to get detector"))));
-    }
 }
diff --git a/src/main/java/org/opensearch/ad/ProfileUtil.java b/src/main/java/org/opensearch/ad/ProfileUtil.java
deleted file mode 100644
index 8afd98dc3..000000000
--- a/src/main/java/org/opensearch/ad/ProfileUtil.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad;
-
-import org.opensearch.action.ActionListener;
-import org.opensearch.action.search.SearchRequest;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyResult;
-import org.opensearch.client.Client;
-import org.opensearch.index.query.BoolQueryBuilder;
-import org.opensearch.index.query.ExistsQueryBuilder;
-import org.opensearch.index.query.QueryBuilders;
-import org.opensearch.search.builder.SearchSourceBuilder;
-import org.opensearch.timeseries.constant.CommonName;
-
-public class ProfileUtil {
-    /**
-     * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time.
-     * Note this function is only meant to check for status of real time analysis.
-     *
-     * @param detectorId detector id
-     * @param enabledTime the time when AD job is enabled in milliseconds
-     * @return the search request
-     */
-    private static SearchRequest createRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) {
-        BoolQueryBuilder filterQuery = new BoolQueryBuilder();
-        filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId));
-        filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime));
-        filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0));
-        // Historical analysis result also stored in result index, which has non-null task_id.
-        // For realtime detection result, we should filter task_id == null
-        ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD);
-        filterQuery.mustNot(taskIdExistsFilter);
-
-        SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1);
-
-        SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS);
-        request.source(source);
-        if (resultIndex != null) {
-            request.indices(resultIndex);
-        }
-        return request;
-    }
-
-    public static void confirmDetectorRealtimeInitStatus(
-        AnomalyDetector detector,
-        long enabledTime,
-        Client client,
-        ActionListener listener
-    ) {
-        SearchRequest searchLatestResult = createRealtimeInittedEverRequest(detector.getId(), enabledTime, detector.getCustomResultIndex());
-        client.search(searchLatestResult, listener);
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java
new file mode 100644
index 000000000..828146516
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/caching/ADCacheBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.caching;
+
+import java.time.Clock;
+import java.time.Duration;
+
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.ml.ADCheckpointDao;
+import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker;
+import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker;
+import org.opensearch.timeseries.MemoryTracker;
+import org.opensearch.timeseries.MemoryTracker.Origin;
+import org.opensearch.timeseries.caching.CacheBuffer;
+
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
+/**
+ * We use a layered cache to manage active entities’ states. We have a two-level
+ * cache that stores active entity states in each node. Each detector has its
+ * dedicated cache that stores ten (dynamically adjustable) entities’ states per
+ * node. A detector’s hottest entities load their states in the dedicated cache.
+ * If less than 10 entities use the dedicated cache, the secondary cache can use
+ * the rest of the free memory available to AD. The secondary cache is a shared
+ * memory among all detectors for the long tail. The shared cache size is 10%
+ * heap minus all of the dedicated cache consumed by single-entity and multi-entity
+ * detectors. The shared cache’s size shrinks as the dedicated cache is filled
+ * up or more detectors are started.
+ *
+ * Implementation-wise, both dedicated cache and shared cache are stored in items
+ * and minimumCapacity controls the boundary. If the items size is equal to or less
+ * than minimumCapacity, consider all items as in the dedicated cache; otherwise,
+ * consider the top minimumCapacity active entities (the last minimumCapacity entities
+ * in priorityList) as in the dedicated cache and all others in the shared cache.
+ */
+public class ADCacheBuffer extends
+    CacheBuffer {
+
+    public ADCacheBuffer(
+        int minimumCapacity,
+        Clock clock,
+        MemoryTracker memoryTracker,
+        int checkpointIntervalHrs,
+        Duration modelTtl,
+        long memoryConsumptionPerEntity,
+        ADCheckpointWriteWorker checkpointWriteQueue,
+        ADCheckpointMaintainWorker checkpointMaintainQueue,
+        String configId,
+        long intervalSecs
+    ) {
+        super(
+            minimumCapacity,
+            clock,
+            memoryTracker,
+            checkpointIntervalHrs,
+            modelTtl,
+            memoryConsumptionPerEntity,
+            checkpointWriteQueue,
+            checkpointMaintainQueue,
+            configId,
+            intervalSecs,
+            Origin.REAL_TIME_DETECTOR
+        );
+    }
+}
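The ADCacheBuffer javadoc above describes the dedicated/shared boundary rule: with the buffer's items ordered from lowest to highest priority, the hottest minimumCapacity entries count against the dedicated cache and everything before them spills into the shared cache. A toy sketch of that partition rule (a hypothetical helper to illustrate the javadoc, not code from this PR):

    import java.util.List;

    final class CacheBoundarySketch {
        // priorityList is ordered coldest first, hottest last; the hottest
        // minimumCapacity entries belong to the dedicated cache and the
        // remainder spill into the shared cache.
        static <T> List<T> dedicatedSlice(List<T> priorityList, int minimumCapacity) {
            int fromIndex = Math.max(0, priorityList.size() - minimumCapacity);
            return priorityList.subList(fromIndex, priorityList.size());
        }

        static <T> List<T> sharedSlice(List<T> priorityList, int minimumCapacity) {
            int fromIndex = Math.max(0, priorityList.size() - minimumCapacity);
            return priorityList.subList(0, fromIndex);
        }
    }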
diff --git a/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java
new file mode 100644
index 000000000..e71c89962
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/caching/ADCacheProvider.java
@@ -0,0 +1,19 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.ad.caching;
+
+import org.opensearch.timeseries.caching.CacheProvider;
+
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
+/**
+ * Allows Guice dependency based on types. Otherwise, Guice cannot
+ * decide which instance to inject based on generic types of CacheProvider
+ *
+ */
+public class ADCacheProvider extends CacheProvider {
+
+}
diff --git a/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java
new file mode 100644
index 000000000..95decc8e8
--- /dev/null
+++ b/src/main/java/org/opensearch/ad/caching/ADPriorityCache.java
@@ -0,0 +1,130 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.ad.caching;
+
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE;
+
+import java.time.Clock;
+import java.time.Duration;
+import java.util.ArrayDeque;
+import java.util.Optional;
+import java.util.concurrent.Callable;
+
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.ml.ADCheckpointDao;
+import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker;
+import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker;
+import org.opensearch.ad.settings.ADEnabledSetting;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Setting;
+import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.MemoryTracker;
+import org.opensearch.timeseries.MemoryTracker.Origin;
+import org.opensearch.timeseries.caching.PriorityCache;
+import org.opensearch.timeseries.ml.ModelManager;
+import org.opensearch.timeseries.ml.ModelState;
+import org.opensearch.timeseries.model.Config;
+
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
+public class ADPriorityCache extends
+    PriorityCache {
+    private ADCheckpointWriteWorker checkpointWriteQueue;
+    private ADCheckpointMaintainWorker checkpointMaintainQueue;
+
+    public ADPriorityCache(
+        ADCheckpointDao checkpointDao,
+        int hcDedicatedCacheSize,
+        Setting checkpointTtl,
+        int maxInactiveStates,
+        MemoryTracker memoryTracker,
+        int numberOfTrees,
+        Clock clock,
+        ClusterService clusterService,
+        Duration modelTtl,
+        ThreadPool threadPool,
+        String threadPoolName,
+        int maintenanceFreqConstant,
+        Settings settings,
+        Setting checkpointSavingFreq,
+        ADCheckpointWriteWorker checkpointWriteQueue,
+        ADCheckpointMaintainWorker checkpointMaintainQueue
+    ) {
+        super(
+            checkpointDao,
+            hcDedicatedCacheSize,
+            checkpointTtl,
+            maxInactiveStates,
+            memoryTracker,
+            numberOfTrees,
+            clock,
+            clusterService,
+            modelTtl,
+            threadPool,
+            threadPoolName,
+            maintenanceFreqConstant,
+            settings,
+            checkpointSavingFreq,
+            Origin.REAL_TIME_DETECTOR,
+            AD_DEDICATED_CACHE_SIZE,
+            AD_MODEL_MAX_SIZE_PERCENTAGE
+        );
+
+        this.checkpointWriteQueue = checkpointWriteQueue;
+        this.checkpointMaintainQueue = checkpointMaintainQueue;
+    }
+
+    @Override
+    protected ADCacheBuffer createEmptyCacheBuffer(Config detector, long memoryConsumptionPerEntity) {
+        return new ADCacheBuffer(
+            detector.isHighCardinality() ? hcDedicatedCacheSize : 1,
+            clock,
+            memoryTracker,
+            checkpointIntervalHrs,
+            modelTtl,
+            memoryConsumptionPerEntity,
+            checkpointWriteQueue,
+            checkpointMaintainQueue,
+            detector.getId(),
+            detector.getIntervalInSeconds()
+        );
+    }
+
+    @Override
+    protected Callable> createInactiveEntityCacheLoader(String modelId, String detectorId) {
+        return new Callable>() {
+            @Override
+            public ModelState call() {
+                return new ModelState<>(
+                    null,
+                    modelId,
+                    detectorId,
+                    ModelManager.ModelType.TRCF.getName(),
+                    clock,
+                    0,
+                    null,
+                    Optional.empty(),
+                    new ArrayDeque<>()
+                );
+            }
+        };
+    }
+
+    @Override
+    protected boolean isDoorKeeperInCacheEnabled() {
+        return ADEnabledSetting.isDoorKeeperInCacheEnabled();
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/caching/CacheProvider.java b/src/main/java/org/opensearch/ad/caching/CacheProvider.java
deleted file mode 100644
index ab8fd191c..000000000
--- a/src/main/java/org/opensearch/ad/caching/CacheProvider.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.caching;
-
-import org.opensearch.common.inject.Provider;
-
-/**
- * A wrapper to call concrete implementation of caching. Used in transport
- * action. Don't use interface because transport action handler constructor
- * requires a concrete class as input.
- *
- */
-public class CacheProvider implements Provider {
-    private EntityCache cache;
-
-    public CacheProvider() {
-
-    }
-
-    @Override
-    public EntityCache get() {
-        return cache;
-    }
-
-    public void set(EntityCache cache) {
-        this.cache = cache;
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/caching/EntityCache.java b/src/main/java/org/opensearch/ad/caching/EntityCache.java
deleted file mode 100644
index 0a6a303d6..000000000
--- a/src/main/java/org/opensearch/ad/caching/EntityCache.java
+++ /dev/null
@@ -1,157 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.caching;
-
-import java.util.Collection;
-import java.util.List;
-import java.util.Optional;
-
-import org.apache.commons.lang3.tuple.Pair;
-import org.opensearch.ad.CleanState;
-import org.opensearch.ad.DetectorModelSize;
-import org.opensearch.ad.MaintenanceState;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.ModelProfile;
-import org.opensearch.timeseries.model.Entity;
-
-public interface EntityCache extends MaintenanceState, CleanState, DetectorModelSize {
-    /**
-     * Get the ModelState associated with the entity. May or may not load the
-     * ModelState depending on the underlying cache's eviction policy.
-     *
-     * @param modelId Model Id
-     * @param detector Detector config object
-     * @return the ModelState associated with the model or null if no cached item
-     * for the entity
-     */
-    ModelState get(String modelId, AnomalyDetector detector);
-
-    /**
-     * Get the number of active entities of a detector
-     * @param detector Detector Id
-     * @return The number of active entities
-     */
-    int getActiveEntities(String detector);
-
-    /**
-     *
-     * @return total active entities in the cache
-     */
-    int getTotalActiveEntities();
-
-    /**
-     * Whether an entity is active or not
-     * @param detectorId The Id of the detector that an entity belongs to
-     * @param entityModelId Entity model Id
-     * @return Whether an entity is active or not
-     */
-    boolean isActive(String detectorId, String entityModelId);
-
-    /**
-     * Get total updates of detector's most active entity's RCF model.
-     *
-     * @param detectorId detector id
-     * @return RCF model total updates of most active entity.
-     */
-    long getTotalUpdates(String detectorId);
-
-    /**
-     * Get RCF model total updates of specific entity
-     *
-     * @param detectorId detector id
-     * @param entityModelId entity model id
-     * @return RCF model total updates of specific entity.
-     */
-    long getTotalUpdates(String detectorId, String entityModelId);
-
-    /**
-     * Gets modelStates of all model hosted on a node
-     *
-     * @return list of modelStates
-     */
-    List> getAllModels();
-
-    /**
-     * Return when the last active time of an entity's state.
-     *
-     * If the entity's state is active in the cache, the value indicates when the cache
-     * is lastly accessed (get/put). If the entity's state is inactive in the cache,
-     * the value indicates when the cache state is created or when the entity is evicted
-     * from active entity cache.
-     *
-     * @param detectorId The Id of the detector that an entity belongs to
-     * @param entityModelId Entity's Model Id
-     * @return if the entity is in the cache, return the timestamp in epoch
-     * milliseconds when the entity's state is lastly used. Otherwise, return -1.
-     */
-    long getLastActiveMs(String detectorId, String entityModelId);
-
-    /**
-     * Release memory when memory circuit breaker is open
-     */
-    void releaseMemoryForOpenCircuitBreaker();
-
-    /**
-     * Select candidate entities for which we can load models
-     * @param cacheMissEntities Cache miss entities
-     * @param detectorId Detector Id
-     * @param detector Detector object
-     * @return A list of entities that are admitted into the cache as a result of the
-     *  update and the left-over entities
-     */
-    Pair, List> selectUpdateCandidate(
-        Collection cacheMissEntities,
-        String detectorId,
-        AnomalyDetector detector
-    );
-
-    /**
-     *
-     * @param detector Detector config
-     * @param toUpdate Model state candidate
-     * @return if we can host the given model state
-     */
-    boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate);
-
-    /**
-     *
-     * @param detectorId Detector Id
-     * @return a detector's model information
-     */
-    List getAllModelProfile(String detectorId);
-
-    /**
-     * Gets an entity's model sizes
-     *
-     * @param detectorId Detector Id
-     * @param entityModelId Entity's model Id
-     * @return the entity's memory size
-     */
-    Optional getModelProfile(String detectorId, String entityModelId);
-
-    /**
-     * Get a model state without incurring priority update. Used in maintenance.
-     * @param detectorId Detector Id
-     * @param modelId Model Id
-     * @return Model state
-     */
-    Optional> getForMaintainance(String detectorId, String modelId);
-
-    /**
-     * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model.
-     * @param detectorId Detector Id
-     * @param entityModelId Model Id
-     */
-    void removeEntityModel(String detectorId, String entityModelId);
-}
diff --git a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java b/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java
index 325361aec..3f8fe461e 100644
--- a/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java
+++ b/src/main/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanup.java
@@ -22,7 +22,6 @@
 import org.opensearch.action.admin.indices.stats.IndicesStatsResponse;
 import org.opensearch.action.admin.indices.stats.ShardStats;
 import org.opensearch.action.support.IndicesOptions;
-import org.opensearch.ad.util.ClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.util.concurrent.ThreadContext;
@@ -30,6 +29,7 @@
 import org.opensearch.index.reindex.DeleteByQueryAction;
 import org.opensearch.index.reindex.DeleteByQueryRequest;
 import org.opensearch.index.store.StoreStats;
+import org.opensearch.timeseries.util.ClientUtil;
 
 /**
  * Clean up the old docs for indices.
diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java
index e20dc8fd1..091a24bd1 100644
--- a/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java
+++ b/src/main/java/org/opensearch/ad/constant/ADCommonMessages.java
@@ -43,7 +43,6 @@ public class ADCommonMessages {
     public static String EXCEED_HISTORICAL_ANALYSIS_LIMIT = "Exceed max historical analysis limit per node";
     public static String NO_ELIGIBLE_NODE_TO_RUN_DETECTOR = "No eligible node to run detector ";
     public static String EMPTY_STALE_RUNNING_ENTITIES = "Empty stale running entities";
-    public static String CAN_NOT_FIND_LATEST_TASK = "can't find latest task";
     public static String NO_ENTITY_FOUND = "No entity found";
     public static String HISTORICAL_ANALYSIS_CANCELLED = "Historical analysis cancelled by user";
     public static String HC_DETECTOR_TASK_IS_UPDATING = "HC detector task is updating";
diff --git a/src/main/java/org/opensearch/ad/constant/ADCommonName.java b/src/main/java/org/opensearch/ad/constant/ADCommonName.java
index 3a97db889..55a2a58be 100644
--- a/src/main/java/org/opensearch/ad/constant/ADCommonName.java
+++ b/src/main/java/org/opensearch/ad/constant/ADCommonName.java
@@ -59,7 +59,6 @@ public class ADCommonName {
     public static final String MODELS = "models";
     public static final String MODEL = "model";
     public static final String INIT_PROGRESS = "init_progress";
-    public static final String CATEGORICAL_FIELD = "category_field";
     public static final String TOTAL_ENTITIES = "total_entities";
     public static final String ACTIVE_ENTITIES = "active_entities";
     public static final String ENTITY_INFO = "entity_info";
@@ -87,11 +86,8 @@ public class ADCommonName {
     public static final String CONFIDENCE_JSON_KEY = "confidence";
     public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade";
     public static final String QUEUE_JSON_KEY = "queue";
-    // ======================================
-    // Used for backward-compatibility in messaging
-    // ======================================
-    public static final String EMPTY_FIELD = "";
+
     // ======================================
     // Validation
     // ======================================
     // detector validation aspect
diff --git a/src/main/java/org/opensearch/ad/constant/CommonValue.java b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java
similarity index 81%
rename from src/main/java/org/opensearch/ad/constant/CommonValue.java
rename to src/main/java/org/opensearch/ad/constant/ADCommonValue.java
index f5d5b15eb..91b9f72f7 100644
--- a/src/main/java/org/opensearch/ad/constant/CommonValue.java
+++ b/src/main/java/org/opensearch/ad/constant/ADCommonValue.java
@@ -11,9 +11,7 @@
 package org.opensearch.ad.constant;
 
-public class CommonValue {
-    // unknown or no schema version
-    public static Integer NO_SCHEMA_VERSION = 0;
+public class ADCommonValue {
     public static String INTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/adinternal/";
     public static String EXTERNAL_ACTION_PREFIX = "cluster:admin/opendistro/ad/";
 }
diff --git a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java
similarity index 60%
rename from src/main/java/org/opensearch/ad/ml/CheckpointDao.java
rename to src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java
index 738acd197..a261cc979 100644
--- a/src/main/java/org/opensearch/ad/ml/CheckpointDao.java
+++ b/src/main/java/org/opensearch/ad/ml/ADCheckpointDao.java
@@ -11,64 +11,52 @@
 package org.opensearch.ad.ml;
 
+import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
+
 import java.io.IOException;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
 import java.time.Clock;
-import java.time.Duration;
 import java.time.Instant;
 import java.time.ZoneOffset;
 import java.time.ZonedDateTime;
-import java.util.AbstractMap.SimpleImmutableEntry;
 import java.util.ArrayDeque;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Base64;
+import java.util.Deque;
 import java.util.HashMap;
 import java.util.List;
-import java.util.Locale;
 import java.util.Map;
-import java.util.Map.Entry;
 import java.util.Optional;
 
 import org.apache.commons.pool2.impl.GenericObjectPool;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.opensearch.ExceptionsHelper;
-import org.opensearch.ResourceAlreadyExistsException;
 import org.opensearch.action.ActionListener;
-import org.opensearch.action.bulk.BulkAction;
-import org.opensearch.action.bulk.BulkItemResponse;
-import org.opensearch.action.bulk.BulkRequest;
-import org.opensearch.action.bulk.BulkResponse;
-import org.opensearch.action.delete.DeleteRequest;
-import org.opensearch.action.delete.DeleteResponse;
 import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
-import org.opensearch.action.get.MultiGetAction;
-import org.opensearch.action.get.MultiGetRequest;
-import org.opensearch.action.get.MultiGetResponse;
-import org.opensearch.action.index.IndexRequest;
-import org.opensearch.action.index.IndexResponse;
 import org.opensearch.action.support.IndicesOptions;
-import org.opensearch.action.update.UpdateRequest;
-import org.opensearch.action.update.UpdateResponse;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.util.ClientUtil;
 import org.opensearch.client.Client;
+import org.opensearch.common.xcontent.json.JsonXContent;
+import org.opensearch.core.xcontent.DeprecationHandler;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
+import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.MatchQueryBuilder; -import org.opensearch.index.reindex.BulkByScrollResponse; -import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; -import org.opensearch.index.reindex.ScrollableHitSource; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.ClientUtil; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.config.Precision; @@ -89,29 +77,18 @@ /** * DAO for model checkpoints. */ -public class CheckpointDao { - - private static final Logger logger = LogManager.getLogger(CheckpointDao.class); - static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; - static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; - static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; - static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; - static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. Has nothing to do:"; - static final String NOT_ABLE_TO_DELETE_LOG_MSG = "Cannot delete all checkpoints of detector"; +public class ADCheckpointDao extends CheckpointDao { + private static final Logger logger = LogManager.getLogger(ADCheckpointDao.class); + // ====================================== + // Model serialization/deserialization + // ====================================== public static final String ENTITY_RCF = "rcf"; public static final String ENTITY_THRESHOLD = "th"; public static final String ENTITY_TRCF = "trcf"; public static final String FIELD_MODELV2 = "modelV2"; public static final String DETECTOR_ID = "detectorId"; - // dependencies - private final Client client; - private final ClientUtil clientUtil; - - // configuration - private final String indexName; - private Gson gson; private RandomCutForestMapper mapper; @@ -130,11 +107,7 @@ public class CheckpointDao { private final ADIndexManagement indexUtil; private final JsonParser parser = new JsonParser(); - // we won't read/write a checkpoint larger than a threshold - private final int maxCheckpointBytes; - private final GenericObjectPool serializeRCFBufferPool; - private final int serializeRCFBufferSize; // anomaly rate private double anomalyRate; @@ -156,10 +129,9 @@ public class CheckpointDao { * @param serializeRCFBufferSize the size of the buffer for RCF serialization * @param anomalyRate anomaly rate */ - public CheckpointDao( + public ADCheckpointDao( Client client, ClientUtil clientUtil, - String indexName, Gson gson, RandomCutForestMapper mapper, V1JsonToV3StateConverter converter, @@ -170,36 +142,29 @@ public CheckpointDao( int maxCheckpointBytes, GenericObjectPool serializeRCFBufferPool, int serializeRCFBufferSize, - double anomalyRate + double anomalyRate, + Clock clock ) { - this.client = client; - this.clientUtil = clientUtil; - this.indexName = indexName; - this.gson = gson; + super( + client, + clientUtil, + ADCommonName.CHECKPOINT_INDEX_NAME, + gson, + maxCheckpointBytes, + 
serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); this.mapper = mapper; this.converter = converter; this.trcfMapper = trcfMapper; this.trcfSchema = trcfSchema; this.thresholdingModelClass = thresholdingModelClass; this.indexUtil = indexUtil; - this.maxCheckpointBytes = maxCheckpointBytes; - this.serializeRCFBufferPool = serializeRCFBufferPool; - this.serializeRCFBufferSize = serializeRCFBufferSize; this.anomalyRate = anomalyRate; } - private void saveModelCheckpointSync(Map source, String modelId) { - clientUtil.timedRequest(new IndexRequest(indexName).id(modelId).source(source), logger, client::index); - } - - private void putModelCheckpoint(String modelId, Map source, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - onCheckpointNotExist(source, modelId, true, listener); - } - } - /** * Puts a rcf model checkpoint in the storage. * @@ -234,66 +199,25 @@ public void putThresholdCheckpoint(String modelId, ThresholdingModel threshold, putModelCheckpoint(modelId, source, listener); } - private void onCheckpointNotExist(Map source, String modelId, boolean isAsync, ActionListener listener) { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - if (isAsync) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - saveModelCheckpointSync(source, modelId); - } - } else { - throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - if (isAsync) { - saveModelCheckpointAsync(source, modelId, listener); - } else { - saveModelCheckpointSync(source, modelId); - } - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); - } - })); - } - - /** - * Update the model doc using fields in source. This ensures we won't touch - * the old checkpoint and nodes with old/new logic can coexist in a cluster. - * This is useful for introducing compact rcf new model format. - * - * @param source fields to update - * @param modelId model Id, used as doc id in the checkpoint index - * @param listener Listener to return response - */ - private void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { - - UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); - updateRequest.doc(source); - // If the document does not already exist, the contents of the upsert element are inserted as a new document. 
- // If the document exists, update fields in the map - updateRequest.docAsUpsert(true); - clientUtil - .asyncRequest( - updateRequest, - client::update, - ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) - ); - } - /** * Prepare for index request using the contents of the given model state * @param modelState an entity model state * @return serialized JSON map or empty map if the state is too bloated * @throws IOException when serialization fails */ - public Map toIndexSource(ModelState modelState) throws IOException { + @Override + public Map toIndexSource(ModelState modelState) throws IOException { String modelId = modelState.getModelId(); Map source = new HashMap<>(); - EntityModel model = modelState.getModel(); - Optional serializedModel = toCheckpoint(model, modelId); + + Object model = modelState.getModel(); + if (modelState.getEntity().isEmpty()) { + throw new IllegalArgumentException("Expect model state to be an entity model"); + } + + ThresholdedRandomCutForest entityModel = (ThresholdedRandomCutForest) model; + + Optional serializedModel = toCheckpoint(entityModel, modelId); if (!serializedModel.isPresent() || serializedModel.get().length() > maxCheckpointBytes) { logger .warn( @@ -305,13 +229,25 @@ public Map toIndexSource(ModelState modelState) thr ); return source; } - String detectorId = modelState.getId(); + source.put(FIELD_MODELV2, serializedModel.get()); + + if (modelState.getSamples() != null && !(modelState.getSamples().isEmpty())) { + source.put(CommonName.ENTITY_SAMPLE_QUEUE, toCheckpoint(modelState.getSamples()).get()); + } + + // if there are no samples and no model, no need to index, as the other information is metadata + if (!source.containsKey(CommonName.ENTITY_SAMPLE_QUEUE) && !source.containsKey(FIELD_MODELV2)) { + return source; + } + + String detectorId = modelState.getConfigId(); source.put(DETECTOR_ID, detectorId); // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value - source.put(FIELD_MODELV2, serializedModel.get()); + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); - source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); - Optional entity = model.getEntity(); + source.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ADIndex.CHECKPOINT)); + + Optional entity = modelState.getEntity(); if (entity.isPresent()) { source.put(CommonName.ENTITY_KEY, entity.get()); } @@ -325,7 +261,7 @@ public Map toIndexSource(ModelState modelState) thr * @param modelId model id * @return serialized string */ - public Optional toCheckpoint(EntityModel model, String modelId) { + public Optional toCheckpoint(ThresholdedRandomCutForest model, String modelId) { return AccessController.doPrivileged((PrivilegedAction>) () -> { if (model == null) { logger.warn("Empty model"); @@ -333,11 +269,8 @@ public Optional toCheckpoint(EntityModel model, String modelId) { } try { JsonObject json = new JsonObject(); - if (model.getSamples() != null && !(model.getSamples().isEmpty())) { - json.add(CommonName.ENTITY_SAMPLE, gson.toJsonTree(model.getSamples())); - } - if (model.getTrcf().isPresent()) { - json.addProperty(ENTITY_TRCF, toCheckpoint(model.getTrcf().get())); + if (model != null) { + json.addProperty(ENTITY_TRCF, toCheckpoint(model)); } // if json is empty, it will be an empty JSON string {}. No need to save it on disk. return json.entrySet().isEmpty() ?
Optional.empty() : Optional.ofNullable(gson.toJson(json)); @@ -382,21 +315,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf) { return checkpoint; } - private Map.Entry checkoutOrNewBuffer() { - LinkedBuffer buffer = null; - boolean isCheckout = true; - try { - buffer = serializeRCFBufferPool.borrowObject(); - } catch (Exception e) { - logger.warn("Failed to borrow a buffer from pool", e); - } - if (buffer == null) { - buffer = LinkedBuffer.allocate(serializeRCFBufferSize); - isCheckout = false; - } - return new SimpleImmutableEntry(buffer, isCheckout); - } - private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer) { try { byte[] bytes = AccessController.doPrivileged((PrivilegedAction) () -> { @@ -409,73 +327,6 @@ private String toCheckpoint(ThresholdedRandomCutForest trcf, LinkedBuffer buffer } } - /** - * Deletes the model checkpoint for the model. - * - * @param modelId id of the model - * @param listener onResponse is called with null when the operation is completed - */ - public void deleteModelCheckpoint(String modelId, ActionListener listener) { - clientUtil - .asyncRequest( - new DeleteRequest(indexName, modelId), - client::delete, - ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) - ); - } - - /** - * Delete checkpoints associated with a detector. Used in multi-entity detector. - * @param detectorID Detector Id - */ - public void deleteModelCheckpointByDetectorId(String detectorID) { - // A bulk delete request is performed for each batch of matching documents. If a - // search or bulk request is rejected, the requests are retried up to 10 times, - // with exponential back off. If the maximum retry limit is reached, processing - // halts and all failed requests are returned in the response. Any delete - // requests that completed successfully still stick, they are not rolled back. - DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(ADCommonName.CHECKPOINT_INDEX_NAME) - .setQuery(new MatchQueryBuilder(DETECTOR_ID, detectorID)) - .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN) - .setAbortOnVersionConflict(false) // when current delete happens, previous might not finish. - // Retry in this case - .setRequestsPerSecond(500); // throttle delete requests - logger.info("Delete checkpoints of detector {}", detectorID); - client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { - if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { - logFailure(response, detectorID); - } - // the response can report 0 docs deleted because: - // 1) we cannot find matching docs - // 2) bad stats from OpenSearch. In this case, docs are deleted, but - // OpenSearch says deleted is 0. - logger.info("{} " + DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); - }, exception -> { - if (exception instanceof IndexNotFoundException) { - logger.info(INDEX_DELETED_LOG_MSG + " {}", detectorID); - } else { - // Gonna eventually delete in daily cron.
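For reference, the throttled, conflict-tolerant delete-by-query pattern used by the removed method above can be reproduced as a self-contained sketch; the class name and the hard-coded field/index names below are illustrative, not part of this PR:

```java
import org.opensearch.action.ActionListener;
import org.opensearch.action.support.IndicesOptions;
import org.opensearch.client.Client;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.index.reindex.BulkByScrollResponse;
import org.opensearch.index.reindex.DeleteByQueryAction;
import org.opensearch.index.reindex.DeleteByQueryRequest;

// Hypothetical helper mirroring the deleted deleteModelCheckpointByDetectorId.
public final class CheckpointCleanupSketch {
    public static void deleteCheckpoints(Client client, String index, String detectorId) {
        DeleteByQueryRequest request = new DeleteByQueryRequest(index)
            .setQuery(new MatchQueryBuilder("detectorId", detectorId))
            .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
            // concurrent deletes may touch the same docs; retry instead of aborting
            .setAbortOnVersionConflict(false)
            // throttle so background cleanup does not starve foreground traffic
            .setRequestsPerSecond(500);
        client.execute(DeleteByQueryAction.INSTANCE, request, ActionListener.wrap(
            (BulkByScrollResponse response) -> System.out.println("deleted: " + response.getDeleted()),
            exception -> System.err.println("cleanup failed: " + exception)
        ));
    }
}
```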
- logger.error(NOT_ABLE_TO_DELETE_LOG_MSG, exception); - } - })); - } - - private void logFailure(BulkByScrollResponse response, String detectorID) { - if (response.isTimedOut()) { - logger.warn(TIMEOUT_LOG_MSG + " {}", detectorID); - } else if (!response.getBulkFailures().isEmpty()) { - logger.warn(BULK_FAILURE_LOG_MSG + " {}", detectorID); - for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { - logger.warn(bulkFailure); - } - } else { - logger.warn(SEARCH_FAILURE_LOG_MSG + " {}", detectorID); - for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { - logger.warn(searchFailure); - } - } - } - /** * Load a JSON checkpoint into models * @@ -484,9 +335,14 @@ private void logFailure(BulkByScrollResponse response, String detectorID) { * @return a pair of entity model and its last checkpoint time; or empty if * the raw checkpoint is too large */ - public Optional> fromEntityModelCheckpoint(Map checkpoint, String modelId) { + @Override + protected ModelState fromEntityModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { try { - return AccessController.doPrivileged((PrivilegedAction>>) () -> { + return AccessController.doPrivileged((PrivilegedAction>) () -> { Object modelObj = checkpoint.get(FIELD_MODELV2); if (modelObj == null) { // in case there is an old-format checkpoint @@ -494,24 +350,14 @@ public Optional> fromEntityModelCheckpoint(Map maxCheckpointBytes) { logger.warn(new ParameterizedMessage("[{}]'s model too large: [{}] bytes", modelId, model.length())); - return Optional.empty(); + return null; } JsonObject json = parser.parse(model).getAsJsonObject(); - ArrayDeque samples = null; - if (json.has(CommonName.ENTITY_SAMPLE)) { - // verified, don't need privileged call to get permission - samples = new ArrayDeque<>( - Arrays.asList(this.gson.fromJson(json.getAsJsonArray(CommonName.ENTITY_SAMPLE), new double[0][0].getClass())) - ); - } else { - // avoid possible null pointer exception - samples = new ArrayDeque<>(); - } ThresholdedRandomCutForest trcf = null; if (json.has(ENTITY_TRCF)) { @@ -540,6 +386,25 @@ public Optional> fromEntityModelCheckpoint(Map sampleQueue = new ArrayDeque<>(); + Object samples = checkpoint.get(CommonName.ENTITY_SAMPLE_QUEUE); + if (samples != null) { + try ( + XContentParser sampleParser = JsonXContent.jsonXContent + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, (String) samples) + ) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, sampleParser.currentToken(), sampleParser); + while (sampleParser.nextToken() != XContentParser.Token.END_ARRAY) { + sampleQueue.add(Sample.parse(sampleParser)); + } + } catch (Exception e) { + logger.warn("Exception while deserializing samples for " + modelId, e); + // checkpoint corrupted (e.g., a checkpoint not recognized by current code + // due to bugs). Better redo training.
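The sample-queue deserialization added above follows the standard OpenSearch XContent array-parsing loop. A minimal runnable sketch of the same loop, with a plain double standing in for the plugin's Sample type (all names below are illustrative):

```java
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

import java.util.ArrayDeque;
import java.util.Deque;

import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;

// Standalone illustration of the ENTITY_SAMPLE_QUEUE parsing loop.
public final class SampleQueueParseSketch {
    public static Deque<Double> parseValues(String json) throws Exception {
        Deque<Double> queue = new ArrayDeque<>();
        try (
            XContentParser parser = JsonXContent.jsonXContent
                .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, json)
        ) {
            parser.nextToken(); // advance onto the first token before checking it
            ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
            while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                queue.add(parser.doubleValue()); // the plugin calls Sample.parse(parser) here
            }
        }
        return queue;
    }

    public static void main(String[] args) throws Exception {
        System.out.println(parseValues("[1.0, 2.0, 3.0]")); // [1.0, 2.0, 3.0]
    }
}
```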
+ return null; + } + } + String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP)); Instant timestamp = Instant.parse(lastCheckpointTimeString); Entity entity = null; @@ -551,14 +416,27 @@ public Optional> fromEntityModelCheckpoint(Map(entityModel, timestamp)); + + ModelState modelState = new ModelState( + trcf, + modelId, + configId, + ModelManager.ModelType.TRCF.getName(), + clock, + 0, + // TODO: track last processed sample in AD + new Sample(), + Optional.ofNullable(entity), + sampleQueue + ); + modelState.setLastCheckpointTime(timestamp); + return modelState; }); } catch (Exception e) { logger.warn("Exception while deserializing checkpoint " + modelId, e); // checkpoint corrupted (e.g., a checkpoint not recognized by current code // due to bugs). Better redo training. - return Optional.empty(); + return null; } } @@ -634,33 +512,14 @@ private void deserializeTRCFModel( } } - /** - * Read a checkpoint from the index and return the EntityModel object - * @param modelId Model Id - * @param listener Listener to return a pair of entity model and its last checkpoint time - */ - public void deserializeModelCheckpoint(String modelId, ActionListener>> listener) { - clientUtil - .asyncRequest( - new GetRequest(indexName, modelId), - client::get, - ActionListener.wrap(response -> { listener.onResponse(processGetResponse(response, modelId)); }, listener::onFailure) - ); - } - - /** - * Process a checkpoint GetResponse and return the EntityModel object - * @param response Checkpoint Index GetResponse - * @param modelId Model Id - * @return a pair of entity model and its last checkpoint time - */ - public Optional> processGetResponse(GetResponse response, String modelId) { - Optional> checkpointString = processRawCheckpoint(response); - if (checkpointString.isPresent()) { - return fromEntityModelCheckpoint(checkpointString.get(), modelId); - } else { - return Optional.empty(); - } + @Override + protected ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ) { + // single stream AD code path is still using old way + throw new UnsupportedOperationException("This method is not supported"); } /** @@ -730,37 +589,6 @@ private Optional processThresholdModelCheckpoint(GetResponse response) { .map(source -> source.get(CommonName.FIELD_MODEL)); } - private Optional> processRawCheckpoint(GetResponse response) { - return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); - } - - public void batchRead(MultiGetRequest request, ActionListener listener) { - clientUtil.execute(MultiGetAction.INSTANCE, request, listener); - } - - public void batchWrite(BulkRequest request, ActionListener listener) { - if (indexUtil.doesCheckpointIndexExist()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - // create index failure. Notify callers using listener. 
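The removed batchWrite above shows the create-index-if-missing handshake that now lives in the shared base class: try the write if the index exists, otherwise create it first and tolerate a concurrent creation. A compact sketch of that handshake, where initCheckpointIndex and executeBulk are stand-ins for the plugin's own helpers:

```java
import org.opensearch.ExceptionsHelper;
import org.opensearch.ResourceAlreadyExistsException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;

// Illustrative skeleton of the ensure-index-then-write pattern.
abstract class EnsureIndexThenWrite {
    abstract boolean checkpointIndexExists();
    abstract void initCheckpointIndex(ActionListener<CreateIndexResponse> listener);
    abstract void executeBulk(BulkRequest request, ActionListener<BulkResponse> listener);

    void batchWrite(BulkRequest request, ActionListener<BulkResponse> listener) {
        if (checkpointIndexExists()) {
            executeBulk(request, listener);
            return;
        }
        initCheckpointIndex(ActionListener.wrap(initResponse -> {
            if (initResponse.isAcknowledged()) {
                executeBulk(request, listener);
            } else {
                listener.onFailure(new IllegalStateException("index creation not acknowledged"));
            }
        }, exception -> {
            // another node may have created the index while our request was in flight
            if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
                executeBulk(request, listener);
            } else {
                listener.onFailure(exception);
            }
        }));
    }
}
```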
- listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we were sending the create request - clientUtil.execute(BulkAction.INSTANCE, request, listener); - } else { - logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); - listener.onFailure(exception); - } - })); - } - } - private Optional convertToTRCF(Optional rcf, Optional kllThreshold) { if (!rcf.isPresent()) { return Optional.empty(); @@ -774,17 +602,13 @@ private Optional convertToTRCF(Optional rcf, Optional kllThreshold) diff --git a/src/main/java/org/opensearch/ad/ml/ADEntityColdStart.java b/src/main/java/org/opensearch/ad/ml/ADEntityColdStart.java new file mode 100644 +public class ADEntityColdStart extends ModelColdStart { + + /** + * Constructor + * + * @param clock UTC clock + * @param threadPool Accessor to different threadpools + * @param nodeStateManager Storing node state + * @param rcfSampleSize The sample size used by stream samplers in this forest + * @param numberOfTrees The number of trees in this forest. + * @param rcfTimeDecay rcf samples time decay constant + * @param numMinSamples The number of points required by stream samplers before + * results are returned. + * @param defaultSampleStride default sample distances measured in detector intervals. + * @param defaultTrainSamples Default train samples to collect. + * @param searchFeatureDao Used to issue ES queries. + * @param thresholdMinPvalue min P-value for thresholding + * @param featureManager Used to create features for models. + * @param modelTtl time-to-live before last access time of the cold start cache. + * We have a cache to record entities that have run cold starts to avoid + * repeated unsuccessful cold start. + * @param checkpointWriteWorker queue to insert model checkpoints + * @param rcfSeed rcf random seed + * @param maxRoundofColdStart max number of rounds of cold start + * @param coolDownMinutes cool down minutes when OpenSearch is overloaded + */ + public ADEntityColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + int defaultSampleStride, + int defaultTrainSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ADCheckpointWriteWorker checkpointWriteWorker, + long rcfSeed, + int maxRoundofColdStart, + int coolDownMinutes + ) { + super( + modelTtl, + coolDownMinutes, + clock, + threadPool, + numMinSamples, + checkpointWriteWorker, + rcfSeed, + numberOfTrees, + rcfSampleSize, + thresholdMinPvalue, + rcfTimeDecay, + nodeStateManager, + defaultSampleStride, + defaultTrainSamples, + searchFeatureDao, + featureManager, + maxRoundofColdStart, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + AnalysisType.AD + ); + } + + public ADEntityColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + int maxSampleStride, + int maxTrainSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ADCheckpointWriteWorker checkpointWriteQueue, + int maxRoundofColdStart, + int coolDownMinutes + ) { + this( + clock, + threadPool, + nodeStateManager, + rcfSampleSize, + numberOfTrees, + rcfTimeDecay, + numMinSamples, + maxSampleStride, + maxTrainSamples, + searchFeatureDao, + thresholdMinPvalue, +
featureManager, + modelTtl, + checkpointWriteQueue, + -1, + maxRoundofColdStart, + coolDownMinutes + ); + } + + /** + * Train model using given data points and save the trained model. + * + * @param pointSamplePair A pair consisting of a queue of continuous data points, + * in ascending order of timestamps and last seen sample. + * @param entity Entity instance + * @param entityState Entity state associated with the model Id + */ + @Override + protected void trainModelFromDataSegments( + Pair pointSamplePair, + Optional entity, + ModelState entityState, + Config config + ) { + if (entity.isEmpty()) { + throw new IllegalArgumentException("We offer only HC cold start"); + } + + double[][] dataPoints = pointSamplePair.getKey(); + if (dataPoints == null || dataPoints.length == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + + double[] firstPoint = dataPoints[0]; + if (firstPoint == null || firstPoint.length == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + int shingleSize = config.getShingleSize(); + int dimensions = firstPoint.length * shingleSize; + ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest + .builder() + .dimensions(dimensions) + .sampleSize(rcfSampleSize) + .numberOfTrees(numberOfTrees) + .timeDecay(rcfTimeDecay) + .outputAfter(numMinSamples) + .initialAcceptFraction(initialAcceptFraction) + .parallelExecutionEnabled(false) + .compact(true) + .precision(Precision.FLOAT_32) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // same with dimension for opportunistic memory saving + // Usually, we use it as shingleSize(dimension). When a new point comes in, we will + // look at the point store if there is any overlapping. Say the previously-stored + // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize + // overlapping x3, x4, and only store x5, x6. 
+ .shingleSize(shingleSize) + .internalShinglingEnabled(true) + .anomalyRate(1 - this.thresholdMinPvalue); + + if (rcfSeed > 0) { + rcfBuilder.randomSeed(rcfSeed); + } + ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); + + for (int i = 0; i < dataPoints.length; i++) { + trcf.process(dataPoints[i], 0); + } + + entityState.setModel(trcf); + + entityState.setLastUsedTime(clock.instant()); + entityState.setLastProcessedSample(pointSamplePair.getValue()); + + // save to checkpoint + checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); + } + + @Override + protected boolean isInterpolationInColdStartEnabled() { + return ADEnabledSetting.isInterpolationInColdStartEnabled(); + } +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelManager.java b/src/main/java/org/opensearch/ad/ml/ADModelManager.java similarity index 71% rename from src/main/java/org/opensearch/ad/ml/ModelManager.java rename to src/main/java/org/opensearch/ad/ml/ADModelManager.java index 464297193..3257e01b4 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelManager.java +++ b/src/main/java/org/opensearch/ad/ml/ADModelManager.java @@ -23,7 +23,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; @@ -31,22 +30,28 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; -import org.opensearch.ad.DetectorModelSize; -import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.DateUtils; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.AnalysisModelSize; +import org.opensearch.timeseries.MemoryTracker; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.ml.MemoryAwareConcurrentHashmap; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.config.Precision; @@ -56,51 +61,27 @@ /** * A facade managing ML operations and models. 
*/ -public class ModelManager implements DetectorModelSize { +public class ADModelManager extends + ModelManager + implements + AnalysisModelSize { protected static final String ENTITY_SAMPLE = "sp"; protected static final String ENTITY_RCF = "rcf"; protected static final String ENTITY_THRESHOLD = "th"; - public enum ModelType { - RCF("rcf"), - THRESHOLD("threshold"), - ENTITY("entity"); - - private String name; - - ModelType(String name) { - this.name = name; - } - - public String getName() { - return name; - } - } - - private static final Logger logger = LogManager.getLogger(ModelManager.class); + private static final Logger logger = LogManager.getLogger(ADModelManager.class); // states - private TRCFMemoryAwareConcurrentHashmap forests; + private MemoryAwareConcurrentHashmap forests; private Map> thresholds; // configuration - private final int rcfNumTrees; - private final int rcfNumSamplesInTree; - private final double rcfTimeDecay; - private final int rcfNumMinSamples; + private final double thresholdMinPvalue; private final int minPreviewSize; private final Duration modelTtl; private Duration checkpointInterval; - // dependencies - private final CheckpointDao checkpointDao; - private final Clock clock; - public FeatureManager featureManager; - - private EntityColdStarter entityColdStarter; - private MemoryTracker memoryTracker; - private final double initialAcceptFraction; /** @@ -122,8 +103,8 @@ public String getName() { * @param settings Node settings * @param clusterService Cluster service accessor */ - public ModelManager( - CheckpointDao checkpointDao, + public ADModelManager( + ADCheckpointDao checkpointDao, Clock clock, int rcfNumTrees, int rcfNumSamplesInTree, @@ -133,18 +114,24 @@ public ModelManager( int minPreviewSize, Duration modelTtl, Setting checkpointIntervalSetting, - EntityColdStarter entityColdStarter, + ADEntityColdStart entityColdStarter, FeatureManager featureManager, MemoryTracker memoryTracker, Settings settings, ClusterService clusterService ) { - this.checkpointDao = checkpointDao; - this.clock = clock; - this.rcfNumTrees = rcfNumTrees; - this.rcfNumSamplesInTree = rcfNumSamplesInTree; - this.rcfTimeDecay = rcfTimeDecay; - this.rcfNumMinSamples = rcfNumMinSamples; + super( + rcfNumTrees, + rcfNumSamplesInTree, + rcfTimeDecay, + rcfNumMinSamples, + entityColdStarter, + memoryTracker, + clock, + featureManager, + checkpointDao + ); + this.thresholdMinPvalue = thresholdMinPvalue; this.minPreviewSize = minPreviewSize; this.modelTtl = modelTtl; @@ -155,12 +142,9 @@ public ModelManager( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); } - this.forests = new TRCFMemoryAwareConcurrentHashmap<>(memoryTracker); + this.forests = new MemoryAwareConcurrentHashmap<>(memoryTracker); this.thresholds = new ConcurrentHashMap<>(); - this.entityColdStarter = entityColdStarter; - this.featureManager = featureManager; - this.memoryTracker = memoryTracker; this.initialAcceptFraction = rcfNumMinSamples * 1.0d / rcfNumSamplesInTree; } @@ -197,10 +181,14 @@ private void getTRcfResult( ) { modelState.setLastUsedTime(clock.instant()); - ThresholdedRandomCutForest trcf = modelState.getModel(); + Optional trcfOptional = modelState.getModel(); + if (trcfOptional.isEmpty()) { + listener.onFailure(new TimeSeriesException("empty model")); + return; + } try { - AnomalyDescriptor result = trcf.process(point, 0); - double[] attribution = normalizeAttribution(trcf.getForest(), result.getRelevantAttribution()); + AnomalyDescriptor 
result = trcfOptional.get().process(point, 0); + double[] attribution = normalizeAttribution(trcfOptional.get().getForest(), result.getRelevantAttribution()); listener .onResponse( new ThresholdingResult( @@ -285,7 +273,17 @@ private Optional> restoreModelState( } return rcfModel .filter(rcf -> memoryTracker.isHostingAllowed(detectorId, rcf)) - .map(rcf -> ModelState.createSingleEntityModelState(rcf, modelId, detectorId, ModelType.RCF.getName(), clock)); + .map( + rcf -> new ModelState( + rcf, + modelId, + detectorId, + ModelManager.ModelType.TRCF.getName(), + clock, + null, + new ArrayDeque() + ) + ); } private void processRestoredTRcf( @@ -319,13 +317,16 @@ private void processRestoredCheckpoint( ) { logger.info("Restoring checkpoint for {}", modelId); Optional> model = restoreModelState(checkpointModel, modelId, detectorId); - if (model.isPresent()) { - forests.put(modelId, model.get()); - if (model.get().getModel() != null && model.get().getModel().getForest() != null) - listener.onResponse(model.get().getModel().getForest().getTotalUpdates()); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); - } + model.ifPresentOrElse(modelState -> { + forests.put(modelId, modelState); + modelState.getModel().ifPresent(trcf -> { + if (trcf.getForest() != null) { + listener.onResponse(trcf.getForest().getTotalUpdates()); + } else { + listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId)); + } + }); + }, () -> listener.onFailure(new ResourceNotFoundException(detectorId, ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelId))); } /** @@ -355,14 +356,26 @@ private void getThresholdingResult( double score, ActionListener listener ) { - ThresholdingModel threshold = modelState.getModel(); - double grade = threshold.grade(score); - double confidence = threshold.confidence(); - if (score > 0) { - threshold.update(score); + Optional thresholdOptional = modelState.getModel(); + if (thresholdOptional.isPresent()) { + ThresholdingModel threshold = thresholdOptional.get(); + double grade = threshold.grade(score); + double confidence = threshold.confidence(); + if (score > 0) { + threshold.update(score); + } + modelState.setLastUsedTime(clock.instant()); + listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } else { + listener + .onFailure( + new ResourceNotFoundException( + modelState.getConfigId(), + ADCommonMessages.NO_CHECKPOINT_ERR_MSG + modelState.getModelId() + ) + ); } - modelState.setLastUsedTime(clock.instant()); - listener.onResponse(new ThresholdingResult(grade, confidence, score)); + } private void processThresholdCheckpoint( @@ -374,7 +387,15 @@ private void processThresholdCheckpoint( ) { Optional> model = thresholdModel .map( - threshold -> ModelState.createSingleEntityModelState(threshold, modelId, detectorId, ModelType.THRESHOLD.getName(), clock) + threshold -> new ModelState<>( + threshold, + modelId, + detectorId, + ModelManager.ModelType.THRESHOLD.getName(), + clock, + null, + new ArrayDeque() + ) ); if (model.isPresent()) { thresholds.put(modelId, model.get()); @@ -423,8 +444,8 @@ private void stopModel(Map> models, String modelId, Ac Optional> modelState = Optional .ofNullable(models.remove(modelId)) .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)); - if (modelState.isPresent()) { - T model = modelState.get().getModel(); + if (modelState.isPresent() && modelState.get().getModel().isPresent()) { + T model = 
modelState.get().getModel().get(); if (model instanceof ThresholdedRandomCutForest) { checkpointDao .putTRCFCheckpoint( @@ -459,29 +480,6 @@ public void clear(String detectorId, ActionListener listener) { clearModels(detectorId, forests, ActionListener.wrap(r -> clearModels(detectorId, thresholds, listener), listener::onFailure)); } - private void clearModels(String detectorId, Map models, ActionListener listener) { - Iterator id = models.keySet().iterator(); - clearModelForIterator(detectorId, models, id, listener); - } - - private void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { - if (idIter.hasNext()) { - String modelId = idIter.next(); - if (SingleStreamModelIdMapper.getDetectorIdForModelId(modelId).equals(detectorId)) { - models.remove(modelId); - checkpointDao - .deleteModelCheckpoint( - modelId, - ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) - ); - } else { - clearModelForIterator(detectorId, models, idIter, listener); - } - } else { - listener.onResponse(null); - } - } - /** * Trains and saves cold-start AD models. * @@ -574,13 +572,18 @@ private void maintenanceForIterator( logger.warn("Failed to finish maintenance for model id " + modelId, e); maintenanceForIterator(models, iter, listener); }); - T model = modelState.getModel(); - if (model instanceof ThresholdedRandomCutForest) { - checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); - } else if (model instanceof ThresholdingModel) { - checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + Optional modelOptional = modelState.getModel(); + if (modelOptional.isPresent()) { + T model = modelOptional.get(); + if (model instanceof ThresholdedRandomCutForest) { + checkpointDao.putTRCFCheckpoint(modelId, (ThresholdedRandomCutForest) model, checkpointListener); + } else if (model instanceof ThresholdingModel) { + checkpointDao.putThresholdCheckpoint(modelId, (ThresholdingModel) model, checkpointListener); + } else { + checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + } } else { - checkpointListener.onFailure(new IllegalArgumentException("Unexpected model type")); + maintenanceForIterator(models, iter, listener); } } else { maintenanceForIterator(models, iter, listener); @@ -618,7 +621,7 @@ public List getPreviewResults(double[][] dataPoints, int shi .parallelExecutionEnabled(false) .compact(true) .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .boundingBoxCacheFraction(TimeSeriesSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) .shingleSize(shingleSize) .anomalyRate(1 - this.thresholdMinPvalue) .build(); @@ -648,15 +651,11 @@ public List getPreviewResults(double[][] dataPoints, int shi @Override public Map getModelSize(String detectorId) { Map res = new HashMap<>(); - forests - .entrySet() - .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) - .forEach(entry -> { res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(entry.getValue().getModel())); }); + res.putAll(forests.getModelSize(detectorId)); thresholds .entrySet() .stream() - .filter(entry -> SingleStreamModelIdMapper.getDetectorIdForModelId(entry.getKey()).equals(detectorId)) + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(detectorId)) .forEach(entry -> { 
res.put(entry.getKey(), (long) memoryTracker.getThresholdModelBytes()); }); return res; } @@ -670,8 +669,8 @@ public Map getModelSize(String detectorId) { public void getTotalUpdates(String modelId, String detectorId, ActionListener listener) { ModelState model = forests.get(modelId); if (model != null) { - if (model.getModel() != null && model.getModel().getForest() != null) { - listener.onResponse(model.getModel().getForest().getTotalUpdates()); + if (model.getModel().isPresent() && model.getModel().get().getForest() != null) { + listener.onResponse(model.getModel().get().getForest().getTotalUpdates()); } else { listener.onResponse(0L); } @@ -685,131 +684,13 @@ public void getTotalUpdates(String modelId, String detectorId, ActionListener modelState, - String modelId, - Entity entity, - int shingleSize - ) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - if (modelState != null) { - EntityModel entityModel = modelState.getModel(); - - if (entityModel == null) { - entityModel = new EntityModel(entity, new ArrayDeque<>(), null); - modelState.setModel(entityModel); - } - - if (!entityModel.getTrcf().isPresent()) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - - if (entityModel.getTrcf().isPresent()) { - result = score(datapoint, modelId, modelState); - } else { - entityModel.addSample(datapoint); - } - } - return result; - } - - public ThresholdingResult score(double[] feature, String modelId, ModelState modelState) { - ThresholdingResult result = new ThresholdingResult(0, 0, 0); - EntityModel model = modelState.getModel(); - try { - if (model != null && model.getTrcf().isPresent()) { - ThresholdedRandomCutForest trcf = model.getTrcf().get(); - Optional.ofNullable(model.getSamples()).ifPresent(q -> { - q.stream().forEach(s -> trcf.process(s, 0)); - q.clear(); - }); - result = toResult(trcf.getForest(), trcf.process(feature, 0)); - } - } catch (Exception e) { - logger - .error( - new ParameterizedMessage( - "Fail to score for [{}]: model Id [{}], feature [{}]", - modelState.getModel().getEntity(), - modelId, - Arrays.toString(feature) - ), - e - ); - throw e; - } finally { - modelState.setLastUsedTime(clock.instant()); - } - return result; - } - - /** - * Instantiate an entity state out of checkpoint. Train models if there are - * enough samples. 
- * @param checkpoint Checkpoint loaded from index - * @param entity objects to access Entity attributes - * @param modelId Model Id - * @param detectorId Detector Id - * @param shingleSize Shingle size - * - * @return updated model state - * - */ - public ModelState processEntityCheckpoint( - Optional> checkpoint, - Entity entity, - String modelId, - String detectorId, - int shingleSize - ) { - // entity state to instantiate - ModelState modelState = new ModelState<>( - new EntityModel(entity, new ArrayDeque<>(), null), - modelId, - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - if (checkpoint.isPresent()) { - Entry modelToTime = checkpoint.get(); - EntityModel restoredModel = modelToTime.getKey(); - combineSamples(modelState.getModel(), restoredModel); - modelState.setModel(restoredModel); - modelState.setLastCheckpointTime(modelToTime.getValue()); - } - EntityModel model = modelState.getModel(); - if (model == null) { - model = new EntityModel(null, new ArrayDeque<>(), null); - modelState.setModel(model); - } - - if (!model.getTrcf().isPresent() && model.getSamples() != null && model.getSamples().size() >= rcfNumMinSamples) { - entityColdStarter.trainModelFromExistingSamples(modelState, shingleSize); - } - return modelState; - } - - private void combineSamples(EntityModel fromModel, EntityModel toModel) { - Queue samples = fromModel.getSamples(); - while (samples.peek() != null) { - toModel.addSample(samples.poll()); - } + @Override + protected ThresholdingResult createEmptyResult() { + return new ThresholdingResult(0, 0, 0); } - private ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { + @Override + protected ThresholdingResult toResult(RandomCutForest rcf, AnomalyDescriptor anomalyDescriptor) { return new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), anomalyDescriptor.getDataConfidence(), diff --git a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java b/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java deleted file mode 100644 index 3f198285f..000000000 --- a/src/main/java/org/opensearch/ad/ml/EntityColdStarter.java +++ /dev/null @@ -1,755 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.ad.ml; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; - -import java.time.Clock; -import java.time.Duration; -import java.time.Instant; -import java.util.AbstractMap.SimpleImmutableEntry; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.Queue; -import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.core.util.Throwables; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.ActionListener; -import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.CleanState; -import org.opensearch.ad.MaintenanceState; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.caching.DoorKeeper; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.common.settings.Settings; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.TimeSeriesException; -import org.opensearch.timeseries.dataprocessor.Imputer; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.settings.TimeSeriesSettings; - -import com.amazon.randomcutforest.config.Precision; -import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - -/** - * Training models for HCAD detectors - * - */ -public class EntityColdStarter implements MaintenanceState, CleanState { - private static final Logger logger = LogManager.getLogger(EntityColdStarter.class); - private final Clock clock; - private final ThreadPool threadPool; - private final NodeStateManager nodeStateManager; - private final int rcfSampleSize; - private final int numberOfTrees; - private final double rcfTimeDecay; - private final int numMinSamples; - private final double thresholdMinPvalue; - private final int defaulStrideLength; - private final int defaultNumberOfSamples; - private final Imputer imputer; - private final SearchFeatureDao searchFeatureDao; - private Instant lastThrottledColdStartTime; - private final FeatureManager featureManager; - private int coolDownMinutes; - // A bloom filter checked before cold start to ensure we don't repeatedly - // retry cold start of the same model. - // keys are detector ids. - private Map doorKeepers; - private final Duration modelTtl; - private final CheckpointWriteWorker checkpointWriteQueue; - // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. 
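Assuming the fixed-seed behavior the comment above describes, two forests built with the same seed and fed the same stream should stay in lockstep, which is exactly what tests rely on. A small sketch (parameters are illustrative, not the plugin's defaults):

```java
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

// Demonstrates that a fixed randomSeed makes TRCF output reproducible.
public final class SeedDeterminismSketch {
    public static void main(String[] args) {
        long seed = 42L;
        ThresholdedRandomCutForest a = ThresholdedRandomCutForest.builder()
            .dimensions(4).shingleSize(4).internalShinglingEnabled(true)
            .sampleSize(256).randomSeed(seed).build();
        ThresholdedRandomCutForest b = ThresholdedRandomCutForest.builder()
            .dimensions(4).shingleSize(4).internalShinglingEnabled(true)
            .sampleSize(256).randomSeed(seed).build();
        for (int i = 0; i < 500; i++) {
            // internal shingling: input length = dimensions / shingleSize = 1
            double[] point = new double[] { Math.sin(i / 10.0) };
            double gradeA = a.process(point, i).getAnomalyGrade();
            double gradeB = b.process(point, i).getAnomalyGrade();
            assert gradeA == gradeB; // same seed, same stream, same output
        }
        System.out.println("forests agreed on all 500 points");
    }
}
```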
- // this is mainly used for testing to make sure the model we trained and the reference rcf produce - // the same results - private final long rcfSeed; - private final int maxRoundofColdStart; - private final double initialAcceptFraction; - - /** - * Constructor - * - * @param clock UTC clock - * @param threadPool Accessor to different threadpools - * @param nodeStateManager Storing node state - * @param rcfSampleSize The sample size used by stream samplers in this forest - * @param numberOfTrees The number of trees in this forest. - * @param rcfTimeDecay rcf samples time decay constant - * @param numMinSamples The number of points required by stream samplers before - * results are returned. - * @param defaultSampleStride default sample distances measured in detector intervals. - * @param defaultTrainSamples Default train samples to collect. - * @param imputer Used to generate data points between samples. - * @param searchFeatureDao Used to issue ES queries. - * @param thresholdMinPvalue min P-value for thresholding - * @param featureManager Used to create features for models. - * @param settings ES settings accessor - * @param modelTtl time-to-live before last access time of the cold start cache. - * We have a cache to record entities that have run cold starts to avoid - * repeated unsuccessful cold start. - * @param checkpointWriteQueue queue to insert model checkpoints - * @param rcfSeed rcf random seed - * @param maxRoundofColdStart max number of rounds of cold start - */ - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int defaultSampleStride, - int defaultTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - long rcfSeed, - int maxRoundofColdStart - ) { - this.clock = clock; - this.lastThrottledColdStartTime = Instant.MIN; - this.threadPool = threadPool; - this.nodeStateManager = nodeStateManager; - this.rcfSampleSize = rcfSampleSize; - this.numberOfTrees = numberOfTrees; - this.rcfTimeDecay = rcfTimeDecay; - this.numMinSamples = numMinSamples; - this.defaulStrideLength = defaultSampleStride; - this.defaultNumberOfSamples = defaultTrainSamples; - this.imputer = imputer; - this.searchFeatureDao = searchFeatureDao; - this.thresholdMinPvalue = thresholdMinPvalue; - this.featureManager = featureManager; - this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); - this.doorKeepers = new ConcurrentHashMap<>(); - this.modelTtl = modelTtl; - this.checkpointWriteQueue = checkpointWriteQueue; - this.rcfSeed = rcfSeed; - this.maxRoundofColdStart = maxRoundofColdStart; - this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; - } - - public EntityColdStarter( - Clock clock, - ThreadPool threadPool, - NodeStateManager nodeStateManager, - int rcfSampleSize, - int numberOfTrees, - double rcfTimeDecay, - int numMinSamples, - int maxSampleStride, - int maxTrainSamples, - Imputer imputer, - SearchFeatureDao searchFeatureDao, - double thresholdMinPvalue, - FeatureManager featureManager, - Settings settings, - Duration modelTtl, - CheckpointWriteWorker checkpointWriteQueue, - int maxRoundofColdStart - ) { - this( - clock, - threadPool, - nodeStateManager, - rcfSampleSize, - numberOfTrees, - rcfTimeDecay, - numMinSamples, - maxSampleStride, - maxTrainSamples, - 
imputer, - searchFeatureDao, - thresholdMinPvalue, - featureManager, - settings, - modelTtl, - checkpointWriteQueue, - -1, - maxRoundofColdStart - ); - } - - /** - * Training a model for an entity - * @param modelId model Id corresponding to the entity - * @param entity the entity's information - * @param detectorId the detector Id corresponding to the entity - * @param modelState model state associated with the entity - * @param listener call back to call after cold start - */ - private void coldStart( - String modelId, - Entity entity, - String detectorId, - ModelState modelState, - AnomalyDetector detector, - ActionListener listener - ) { - logger.debug("Trigger cold start for {}", modelId); - - if (modelState == null || entity == null) { - listener - .onFailure( - new IllegalArgumentException( - String - .format( - Locale.ROOT, - "Cannot have empty model state or entity: model state [%b], entity [%b]", - modelState == null, - entity == null - ) - ) - ); - return; - } - - if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { - listener.onResponse(null); - return; - } - - boolean earlyExit = true; - try { - DoorKeeper doorKeeper = doorKeepers - .computeIfAbsent( - detectorId, - id -> { - // reset every 60 intervals - return new DoorKeeper( - TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, - TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), - clock - ); - } - ); - - // Won't retry cold start within 60 intervals for an entity - if (doorKeeper.mightContain(modelId)) { - return; - } - - doorKeeper.put(modelId); - - ActionListener>> coldStartCallBack = ActionListener.wrap(trainingData -> { - try { - if (trainingData.isPresent()) { - List dataPoints = trainingData.get(); - extractTrainSamples(dataPoints, modelId, modelState); - Queue samples = modelState.getModel().getSamples(); - // only train models if we have enough samples - if (samples.size() >= numMinSamples) { - // The function trainModelFromDataSegments will save a trained model. trainModelFromDataSegments is called from - multiple places, so I want to make saving the model implicit just in case I forget.
- trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize()); - logger.info("Succeeded in training entity: {}", modelId); - } else { - // save to checkpoint - checkpointWriteQueue.write(modelState, true, RequestPriority.MEDIUM); - logger.info("Not enough data to train entity: {}, currently we have {}", modelId, samples.size()); - } - } else { - logger.info("Cannot get training data for {}", modelId); - } - listener.onResponse(null); - } catch (Exception e) { - listener.onFailure(e); - } - }, exception -> { - try { - logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); - Throwable cause = Throwables.getRootCause(exception); - if (ExceptionUtil.isOverloaded(cause)) { - logger.error("too many requests"); - lastThrottledColdStartTime = Instant.now(); - } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { - // e.g., cannot find anomaly detector - nodeStateManager.setException(detectorId, exception); - } else { - nodeStateManager.setException(detectorId, new TimeSeriesException(detectorId, cause)); - } - listener.onFailure(exception); - } catch (Exception e) { - listener.onFailure(e); - } - }); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> getEntityColdStartData( - detectorId, - entity, - new ThreadedActionListener<>( - logger, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - coldStartCallBack, - false - ) - ) - ); - earlyExit = false; - } finally { - if (earlyExit) { - listener.onResponse(null); - } - } - } - - /** - * Train model using given data points and save the trained model. - * - * @param dataPoints Queue of continuous data points, in ascending order of timestamps - * @param entity Entity instance - * @param entityState Entity state associated with the model Id - */ - private void trainModelFromDataSegments( - Queue dataPoints, - Entity entity, - ModelState entityState, - int shingleSize - ) { - if (dataPoints == null || dataPoints.size() == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - - double[] firstPoint = dataPoints.peek(); - if (firstPoint == null || firstPoint.length == 0) { - throw new IllegalArgumentException("Data points must not be empty."); - } - int dimensions = firstPoint.length * shingleSize; - ThresholdedRandomCutForest.Builder rcfBuilder = ThresholdedRandomCutForest - .builder() - .dimensions(dimensions) - .sampleSize(rcfSampleSize) - .numberOfTrees(numberOfTrees) - .timeDecay(rcfTimeDecay) - .outputAfter(numMinSamples) - .initialAcceptFraction(initialAcceptFraction) - .parallelExecutionEnabled(false) - .compact(true) - .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) - // same with dimension for opportunistic memory saving - // Usually, we use it as shingleSize(dimension). When a new point comes in, we will - // look at the point store if there is any overlapping. Say the previously-stored - // vector is x1, x2, x3, x4, now we add x3, x4, x5, x6. RCF will recognize - // overlapping x3, x4, and only store x5, x6. 
- .shingleSize(shingleSize) - .internalShinglingEnabled(true) - .anomalyRate(1 - this.thresholdMinPvalue); - - if (rcfSeed > 0) { - rcfBuilder.randomSeed(rcfSeed); - } - ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(rcfBuilder); - - while (!dataPoints.isEmpty()) { - trcf.process(dataPoints.poll(), 0); - } - - EntityModel model = entityState.getModel(); - if (model == null) { - model = new EntityModel(entity, new ArrayDeque<>(), null); - } - model.setTrcf(trcf); - - entityState.setLastUsedTime(clock.instant()); - - // save to checkpoint - checkpointWriteQueue.write(entityState, true, RequestPriority.MEDIUM); - } - - /** - * Get training data for an entity. - * - * We first note the maximum and minimum timestamp, and sample at most 24 points - * (with 60 points apart between two neighboring samples) between those minimum - * and maximum timestamps. Samples can be missing. We only interpolate points - * between present neighboring samples. We then transform samples and interpolate - * points to shingles. Finally, full shingles will be used for cold start. - * - * @param detectorId detector Id - * @param entity the entity's information - * @param listener listener to return training data - */ - private void getEntityColdStartData(String detectorId, Entity entity, ActionListener>> listener) { - ActionListener> getDetectorListener = ActionListener.wrap(detectorOp -> { - if (!detectorOp.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - List coldStartData = new ArrayList<>(); - AnomalyDetector detector = detectorOp.get(); - - ActionListener> minTimeListener = ActionListener.wrap(earliest -> { - if (earliest.isPresent()) { - long startTimeMs = earliest.get().longValue(); - - // End time uses milliseconds as start time is assumed to be in milliseconds. - // Opensearch uses a set of preconfigured formats to recognize and parse these - // strings into a long value - // representing milliseconds-since-the-epoch in UTC. 
-/**
- * Get training data for an entity.
- *
- * We first note the maximum and minimum timestamp, and sample at most 24 points
- * (neighboring samples one stride apart) between those minimum and maximum
- * timestamps. Samples can be missing. We only interpolate points between present
- * neighboring samples. We then transform the samples and interpolated points
- * into shingles. Finally, full shingles are used for cold start.
- *
- * @param detectorId detector Id
- * @param entity the entity's information
- * @param listener listener to return training data
- */
-private void getEntityColdStartData(String detectorId, Entity entity, ActionListener<Optional<List<double[][]>>> listener) {
-    ActionListener<Optional<AnomalyDetector>> getDetectorListener = ActionListener.wrap(detectorOp -> {
-        if (!detectorOp.isPresent()) {
-            listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false));
-            return;
-        }
-        List<double[][]> coldStartData = new ArrayList<>();
-        AnomalyDetector detector = detectorOp.get();
-
-        ActionListener<Optional<Long>> minTimeListener = ActionListener.wrap(earliest -> {
-            if (earliest.isPresent()) {
-                long startTimeMs = earliest.get().longValue();
-
-                // End time uses milliseconds as start time is assumed to be in milliseconds.
-                // OpenSearch uses a set of preconfigured formats to recognize and parse these
-                // strings into a long value representing milliseconds-since-the-epoch in UTC.
-                // More on https://tinyurl.com/wub4fk92
-
-                long endTimeMs = clock.millis();
-                Pair<Integer, Integer> params = selectRangeParam(detector);
-                int stride = params.getLeft();
-                int numberOfSamples = params.getRight();
-
-                // we start with round 0
-                getFeatures(listener, 0, coldStartData, detector, entity, stride, numberOfSamples, startTimeMs, endTimeMs);
-            } else {
-                listener.onResponse(Optional.empty());
-            }
-        }, listener::onFailure);
-
-        searchFeatureDao
-            .getEntityMinDataTime(
-                detector,
-                entity,
-                new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, minTimeListener, false)
-            );
-
-    }, listener::onFailure);
-
-    nodeStateManager
-        .getAnomalyDetector(
-            detectorId,
-            new ThreadedActionListener<>(logger, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, getDetectorListener, false)
-        );
-}
-
-private void getFeatures(
-    ActionListener<Optional<List<double[][]>>> listener,
-    int round,
-    List<double[][]> lastRoundColdStartData,
-    AnomalyDetector detector,
-    Entity entity,
-    int stride,
-    int numberOfSamples,
-    long startTimeMs,
-    long endTimeMs
-) {
-    if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < detector.getIntervalInMilliseconds()) {
-        listener.onResponse(Optional.of(lastRoundColdStartData));
-        return;
-    }
-
-    // create ranges in descending order of time; OpenSearch's response re-orders
-    // them in ascending order
-    List<Entry<Long, Long>> sampleRanges = getTrainSampleRanges(detector, startTimeMs, endTimeMs, stride, numberOfSamples);
-
-    if (sampleRanges.isEmpty()) {
-        listener.onResponse(Optional.of(lastRoundColdStartData));
-        return;
-    }
-
-    ActionListener<List<Optional<double[]>>> getFeaturelistener = ActionListener.wrap(featureSamples -> {
-        Pair<Integer, double[]> lastSample = null;
-        List<double[][]> currentRoundColdStartData = new ArrayList<>();
-
-        // featureSamples are in ascending order of time.
-        for (int i = 0; i < featureSamples.size(); i++) {
-            Optional<double[]> featuresOptional = featureSamples.get(i);
-            if (featuresOptional.isPresent()) {
-                // we only need the most recent two samples; missing samples are
-                // filled by linear interpolation as well.
-                // Denote the samples S0, S1, ... in reverse order of time.
-                // Each [Si, Si-1] corresponds to strideLength * detector interval.
-                // If we got samples for S0, S1, S4 (both S2 and S3 are missing),
-                // we interpolate [S4, S1] into 3 * strideLength pieces.
-                if (lastSample != null) {
-                    // the right sample has index i and feature featuresOptional.get()
-                    int numInterpolants = (i - lastSample.getLeft()) * stride + 1;
-                    double[][] points = featureManager
-                        .transpose(
-                            imputer
-                                .impute(
-                                    featureManager.transpose(new double[][] { lastSample.getRight(), featuresOptional.get() }),
-                                    numInterpolants
-                                )
-                        );
-                    // the last point is either included in the next iteration or processed
-                    // at the end; drop it here so no sample is added twice.
-                    currentRoundColdStartData.add(Arrays.copyOfRange(points, 0, points.length - 1));
-                }
-                lastSample = Pair.of(i, featuresOptional.get());
-            }
-        }
-
-        if (lastSample != null) {
-            currentRoundColdStartData.add(new double[][] { lastSample.getRight() });
-        }
-        if (lastRoundColdStartData.size() > 0) {
-            currentRoundColdStartData.addAll(lastRoundColdStartData);
-        }
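The imputation above goes through the plugin's featureManager/imputer helpers. Stripped of those, the arithmetic is plain linear interpolation between two neighboring samples; a minimal standalone sketch (a hypothetical helper, not the plugin's Imputer API):

    final class LinearImpute {
        // Evenly spaces `total` points from `left` to `right`, inclusive of both
        // endpoints, mirroring the numInterpolants count computed above.
        static double[][] interpolate(double[] left, double[] right, int total) {
            double[][] out = new double[total][left.length];
            for (int i = 0; i < total; i++) {
                double t = total == 1 ? 0 : (double) i / (total - 1);
                for (int d = 0; d < left.length; d++) {
                    out[i][d] = left[d] + t * (right[d] - left[d]);
                }
            }
            return out;
        }
    }
    // With stride 6 and present samples at indexes 1 and 4, total = (4 - 1) * 6 + 1 = 19
    // points span the gap; the caller then drops the last one to avoid duplicates.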
-        // If the first round of probing provides (32 + shingleSize) points, we are done
-        // (note that if S0 is missing, or every Si past some index is missing, we would
-        // miss a lot of points). Otherwise we issue another round of queries: if there is
-        // any sample in the second round, we would have 32 + shingleSize points; if there
-        // is no sample in the second round, we should wait for real data.
-        if (calculateColdStartDataSize(currentRoundColdStartData) >= detector.getShingleSize() + numMinSamples
-            || round + 1 >= maxRoundofColdStart) {
-            listener.onResponse(Optional.of(currentRoundColdStartData));
-        } else {
-            // the oldest sample's start time is the endTimeMs of the next round of probing.
-            long lastSampleStartTime = sampleRanges.get(sampleRanges.size() - 1).getKey();
-            getFeatures(
-                listener,
-                round + 1,
-                currentRoundColdStartData,
-                detector,
-                entity,
-                stride,
-                numberOfSamples,
-                startTimeMs,
-                lastSampleStartTime
-            );
-        }
-    }, listener::onFailure);
-
-    try {
-        searchFeatureDao
-            .getColdStartSamplesForPeriods(
-                detector,
-                sampleRanges,
-                entity,
-                // Accept empty buckets. A 0 returned by the engine constitutes a valid
-                // answer; "null" is a missing answer. 0 may be meaningless in some cases
-                // but meaningful in others. The query defining the metric may be
-                // ill-formed, but that cannot be solved by the AD plugin's cold-start
-                // strategy; attempting to would break legitimate interpretations of 0.
-                true,
-                new ThreadedActionListener<>(
-                    logger,
-                    threadPool,
-                    TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME,
-                    getFeaturelistener,
-                    false
-                )
-            );
-    } catch (Exception e) {
-        listener.onFailure(e);
-    }
-}
-
-private int calculateColdStartDataSize(List<double[][]> coldStartData) {
-    int size = 0;
-    for (int i = 0; i < coldStartData.size(); i++) {
-        size += coldStartData.get(i).length;
-    }
-    return size;
-}
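To make the stopping condition above concrete, assume the defaults the comments suggest (numMinSamples = 32, shingleSize = 8; both assumed here):

    int shingleSize = 8;                        // assumed detector shingle size
    int numMinSamples = 32;                     // assumed minimum sample count
    int needed = shingleSize + numMinSamples;   // 40 points end the probing
    // Round 0 probes [startTimeMs, now]. If fewer than 40 points come back, round 1
    // re-probes [startTimeMs, start of round 0's oldest sample], and so on, until
    // either `needed` points accumulate or the round counter hits maxRoundofColdStart.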
-/**
- * Select strideLength and numberOfSamples, where strideLength is the number of
- * intervals between two samples and numberOfSamples is the number of training
- * samples to fetch. If interpolation is disabled, strideLength is 1 and
- * numberOfSamples is shingleSize + numMinSamples.
- *
- * Algorithm:
- *
- * delta is the length of the detector interval in minutes.
- *
- * 1. Suppose delta <= 30 and divides 60. Then set numberOfSamples = ceil((shingleSize + 32) / 24) * 24
- * and strideLength = 60 / delta. Note that if there is enough data we may get far more than
- * shingleSize + 32 points, which is only good. This step tries to match data with an hourly pattern.
- * 2. Otherwise, set numberOfSamples = shingleSize + 32 and strideLength = 1.
- * This should be an uncommon case, as we assume most users think in multiples of 5 minutes
- * (say 10 or 30 minutes). But if someone wants a 23-minute interval -- and the system permits it --
- * we give it to them. In this case we disable interpolation, since interpolation here assumes
- * the hourly pattern of case 1 (that is why 60 is the dividend there), and the 23-minute case
- * does not fit that pattern. Note the smallest delta that does not divide 60 is 7, which is
- * already a long time to wait for one data point.
- * @return the chosen strideLength and numberOfSamples
- */
-private Pair<Integer, Integer> selectRangeParam(AnomalyDetector detector) {
-    int shingleSize = detector.getShingleSize();
-    if (ADEnabledSetting.isInterpolationInColdStartEnabled()) {
-        long delta = detector.getIntervalInMinutes();
-
-        int strideLength = defaulStrideLength;
-        int numberOfSamples = defaultNumberOfSamples;
-        if (delta <= 30 && 60 % delta == 0) {
-            strideLength = (int) (60 / delta);
-            numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24;
-        } else {
-            strideLength = 1;
-            numberOfSamples = shingleSize + numMinSamples;
-        }
-        return Pair.of(strideLength, numberOfSamples);
-    } else {
-        return Pair.of(1, shingleSize + numMinSamples);
-    }
-}
-
-/**
- * Get train samples within a time range.
- *
- * @param detector accessor to detector config
- * @param startMilli range start
- * @param endMilli range end
- * @param stride the number of intervals between two samples
- * @param numberOfSamples maximum training samples to fetch
- * @return list of sample time ranges, in descending order of time
- */
-private List<Entry<Long, Long>> getTrainSampleRanges(
-    AnomalyDetector detector,
-    long startMilli,
-    long endMilli,
-    int stride,
-    int numberOfSamples
-) {
-    long bucketSize = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMillis();
-    int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize);
-    // adjust if numStrides is more than the max samples
-    int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples);
-    List<Entry<Long, Long>> sampleRanges = Stream
-        .iterate(endMilli, i -> i - stride * bucketSize)
-        .limit(numStrides)
-        .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time))
-        .collect(Collectors.toList());
-    return sampleRanges;
-}
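A worked pass through case 1 of the algorithm above, for a 10-minute interval and shingle size 8 (values assumed for illustration):

    long delta = 10;                                                         // interval in minutes
    int shingleSize = 8;
    int strideLength = (int) (60 / delta);                                   // 6 intervals between samples
    int numberOfSamples = (int) Math.ceil((shingleSize + 32) / 24.0d) * 24;  // ceil(40/24)*24 = 48
    // Neighboring samples are strideLength * delta = 60 minutes apart, so 48 samples
    // probe roughly two days of history per round, each sample one interval wide.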
-/**
- * Train models for the given entity.
- *
- * @param entity the entity info
- * @param detectorId detector Id
- * @param modelState model state associated with the entity
- * @param listener callback invoked before the method returns, whenever EntityColdStarter
- *  finishes training or encounters exceptions. The listener helps notify the
- *  cold start queue to pull another request (if any) to execute.
- */
-public void trainModel(Entity entity, String detectorId, ModelState<EntityModel> modelState, ActionListener<Void> listener) {
-    nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> {
-        if (false == detectorOptional.isPresent()) {
-            logger.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
-            listener.onFailure(new TimeSeriesException(detectorId, "fail to find detector"));
-            return;
-        }
-
-        AnomalyDetector detector = detectorOptional.get();
-
-        Queue<double[]> samples = modelState.getModel().getSamples();
-        String modelId = modelState.getModelId();
-
-        if (samples.size() < this.numMinSamples) {
-            // we cannot get the last RCF score since cold start happens asynchronously
-            coldStart(modelId, entity, detectorId, modelState, detector, listener);
-        } else {
-            try {
-                trainModelFromDataSegments(samples, entity, modelState, detector.getShingleSize());
-                listener.onResponse(null);
-            } catch (Exception e) {
-                listener.onFailure(e);
-            }
-        }
-
-    }, listener::onFailure));
-}
-
-public void trainModelFromExistingSamples(ModelState<EntityModel> modelState, int shingleSize) {
-    if (modelState == null || modelState.getModel() == null || modelState.getModel().getSamples() == null) {
-        return;
-    }
-
-    EntityModel model = modelState.getModel();
-    Queue<double[]> samples = model.getSamples();
-    if (samples.size() >= this.numMinSamples) {
-        try {
-            trainModelFromDataSegments(samples, model.getEntity().orElse(null), modelState, shingleSize);
-        } catch (Exception e) {
-            // e.g., an exception from RCF. We can do nothing except log the error;
-            // we won't retry training for the same entity within the cooldown period
-            // (60 detector intervals).
-            logger.error("Unexpected training error", e);
-        }
-    }
-}
-
-/**
- * Extract training data and put it into the ModelState.
- *
- * @param coldstartDatapoints training data generated from cold start
- * @param modelId model Id
- * @param modelState entity state
- */
-private void extractTrainSamples(List<double[][]> coldstartDatapoints, String modelId, ModelState<EntityModel> modelState) {
-    if (coldstartDatapoints == null || coldstartDatapoints.size() == 0 || modelState == null) {
-        return;
-    }
-
-    EntityModel model = modelState.getModel();
-    if (model == null) {
-        model = new EntityModel(null, new ArrayDeque<>(), null);
-        modelState.setModel(model);
-    }
-
-    Queue<double[]> newSamples = new ArrayDeque<>();
-    for (double[][] consecutivePoints : coldstartDatapoints) {
-        for (int i = 0; i < consecutivePoints.length; i++) {
-            newSamples.add(consecutivePoints[i]);
-        }
-    }
-
-    model.setSamples(newSamples);
-}
-
-@Override
-public void maintenance() {
-    doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> {
-        String detectorId = doorKeeperEntry.getKey();
-        DoorKeeper doorKeeper = doorKeeperEntry.getValue();
-        if (doorKeeper.expired(modelTtl)) {
-            doorKeepers.remove(detectorId);
-        } else {
-            doorKeeper.maintenance();
-        }
-    });
-}
-
-@Override
-public void clear(String detectorId) {
-    doorKeepers.remove(detectorId);
-}
-}
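The next hunk deletes TRCFMemoryAwareConcurrentHashmap, whose pattern (account for memory on every put/remove) is worth keeping in mind. A standalone sketch of the same idea, with a simplified counter standing in for the plugin's MemoryTracker API:

    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.atomic.AtomicLong;
    import java.util.function.ToLongFunction;

    // Sketch only: `sizer` estimates a value's footprint; `used` tracks the total.
    // Like the deleted class, it accounts on put/remove rather than sizing the map.
    class MemoryAwareMap<K, V> extends ConcurrentHashMap<K, V> {
        private final ToLongFunction<V> sizer;
        private final AtomicLong used = new AtomicLong();

        MemoryAwareMap(ToLongFunction<V> sizer) {
            this.sizer = sizer;
        }

        @Override
        public V put(K key, V value) {
            if (value != null) {
                used.addAndGet(sizer.applyAsLong(value));    // consume on insert
            }
            return super.put(key, value);
        }

        @Override
        public V remove(Object key) {
            V removed = super.remove(key);
            if (removed != null) {
                used.addAndGet(-sizer.applyAsLong(removed)); // release on delete
            }
            return removed;
        }

        long bytesUsed() {
            return used.get();
        }
    }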
diff --git a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java
deleted file mode 100644
index 7b7b1fe7d..000000000
--- a/src/main/java/org/opensearch/ad/ml/TRCFMemoryAwareConcurrentHashmap.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.ml;
-
-import java.util.concurrent.ConcurrentHashMap;
-
-import org.opensearch.ad.MemoryTracker;
-import org.opensearch.ad.MemoryTracker.Origin;
-
-import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
-
-/**
- * A customized ConcurrentHashMap that can automatically consume and release memory.
- * This enables minimal change to our single-entity code, as we just have to replace
- * the map implementation.
- *
- * Note: this is mainly used for single-entity detectors.
- */
-public class TRCFMemoryAwareConcurrentHashmap<K> extends ConcurrentHashMap<K, ModelState<EntityModel>> {
-    private final MemoryTracker memoryTracker;
-
-    public TRCFMemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) {
-        this.memoryTracker = memoryTracker;
-    }
-
-    @Override
-    public ModelState<EntityModel> remove(Object key) {
-        ModelState<EntityModel> deletedModelState = super.remove(key);
-        if (deletedModelState != null && deletedModelState.getModel() != null) {
-            long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel());
-            memoryTracker.releaseMemory(memoryToRelease, true, Origin.SINGLE_ENTITY_DETECTOR);
-        }
-        return deletedModelState;
-    }
-
-    @Override
-    public ModelState<EntityModel> put(K key, ModelState<EntityModel> value) {
-        ModelState<EntityModel> previousAssociatedState = super.put(key, value);
-        if (value != null && value.getModel() != null) {
-            long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel());
-            memoryTracker.consumeMemory(memoryToConsume, true, Origin.SINGLE_ENTITY_DETECTOR);
-        }
-        return previousAssociatedState;
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/model/ADTask.java b/src/main/java/org/opensearch/ad/model/ADTask.java
index 0004f9640..2b72c8197 100644
--- a/src/main/java/org/opensearch/ad/model/ADTask.java
+++ b/src/main/java/org/opensearch/ad/model/ADTask.java
@@ -11,7 +11,6 @@
 package org.opensearch.ad.model;
 
-import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES;
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 
 import java.io.IOException;
@@ -20,13 +19,12 @@
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.core.common.io.stream.Writeable;
-import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.timeseries.annotation.Generated;
 import org.opensearch.timeseries.model.DateRange;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.TimeSeriesTask;
 import org.opensearch.timeseries.util.ParseUtils;
 
 import com.google.common.base.Objects;
@@ -34,64 +32,21 @@
 /**
  * One anomaly detection task means one detector starts to run until stopped.
*/ -public class ADTask implements ToXContentObject, Writeable { +public class ADTask extends TimeSeriesTask { - public static final String TASK_ID_FIELD = "task_id"; - public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; - public static final String STARTED_BY_FIELD = "started_by"; - public static final String STOPPED_BY_FIELD = "stopped_by"; - public static final String ERROR_FIELD = "error"; - public static final String STATE_FIELD = "state"; public static final String DETECTOR_ID_FIELD = "detector_id"; - public static final String TASK_PROGRESS_FIELD = "task_progress"; - public static final String INIT_PROGRESS_FIELD = "init_progress"; - public static final String CURRENT_PIECE_FIELD = "current_piece"; - public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; - public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; - public static final String IS_LATEST_FIELD = "is_latest"; - public static final String TASK_TYPE_FIELD = "task_type"; - public static final String CHECKPOINT_ID_FIELD = "checkpoint_id"; - public static final String COORDINATING_NODE_FIELD = "coordinating_node"; - public static final String WORKER_NODE_FIELD = "worker_node"; public static final String DETECTOR_FIELD = "detector"; public static final String DETECTION_DATE_RANGE_FIELD = "detection_date_range"; - public static final String ENTITY_FIELD = "entity"; - public static final String PARENT_TASK_ID_FIELD = "parent_task_id"; - public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; - public static final String USER_FIELD = "user"; - public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; - private String taskId = null; - private Instant lastUpdateTime = null; - private String startedBy = null; - private String stoppedBy = null; - private String error = null; - private String state = null; - private String detectorId = null; - private Float taskProgress = null; - private Float initProgress = null; - private Instant currentPiece = null; - private Instant executionStartTime = null; - private Instant executionEndTime = null; - private Boolean isLatest = null; - private String taskType = null; - private String checkpointId = null; private AnomalyDetector detector = null; - - private String coordinatingNode = null; - private String workerNode = null; private DateRange detectionDateRange = null; - private Entity entity = null; - private String parentTaskId = null; - private Integer estimatedMinutesLeft = null; - private User user = null; private ADTask() {} public ADTask(StreamInput input) throws IOException { this.taskId = input.readOptionalString(); this.taskType = input.readOptionalString(); - this.detectorId = input.readOptionalString(); + this.configId = input.readOptionalString(); if (input.readBoolean()) { this.detector = new AnomalyDetector(input); } else { @@ -137,7 +92,7 @@ public ADTask(StreamInput input) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(taskId); out.writeOptionalString(taskType); - out.writeOptionalString(detectorId); + out.writeOptionalString(configId); if (detector != null) { out.writeBoolean(true); detector.writeTo(out); @@ -185,175 +140,40 @@ public static Builder builder() { return new Builder(); } - public boolean isHistoricalTask() { - return taskType.startsWith(HISTORICAL_TASK_PREFIX); - } - + @Override public boolean isEntityTask() { - return ADTaskType.HISTORICAL_HC_ENTITY.name().equals(taskType); - } - - /** - * Get detector level 
task id. If a task has no parent task, the task is detector level task. - * @return detector level task id - */ - public String getDetectorLevelTaskId() { - return getParentTaskId() != null ? getParentTaskId() : getTaskId(); - } - - public boolean isDone() { - return !NOT_ENDED_STATES.contains(this.getState()); + return ADTaskType.AD_HISTORICAL_HC_ENTITY.name().equals(taskType); } - public static class Builder { - private String taskId = null; - private String taskType = null; + public static class Builder extends TimeSeriesTask.Builder { private String detectorId = null; private AnomalyDetector detector = null; - private String state = null; - private Float taskProgress = null; - private Float initProgress = null; - private Instant currentPiece = null; - private Instant executionStartTime = null; - private Instant executionEndTime = null; - private Boolean isLatest = null; - private String error = null; - private String checkpointId = null; - private Instant lastUpdateTime = null; - private String startedBy = null; - private String stoppedBy = null; - private String coordinatingNode = null; - private String workerNode = null; private DateRange detectionDateRange = null; - private Entity entity = null; - private String parentTaskId; - private Integer estimatedMinutesLeft; - private User user = null; public Builder() {} - public Builder taskId(String taskId) { - this.taskId = taskId; - return this; - } - - public Builder lastUpdateTime(Instant lastUpdateTime) { - this.lastUpdateTime = lastUpdateTime; - return this; - } - - public Builder startedBy(String startedBy) { - this.startedBy = startedBy; - return this; - } - - public Builder stoppedBy(String stoppedBy) { - this.stoppedBy = stoppedBy; - return this; - } - - public Builder error(String error) { - this.error = error; - return this; - } - - public Builder state(String state) { - this.state = state; - return this; - } - public Builder detectorId(String detectorId) { this.detectorId = detectorId; return this; } - public Builder taskProgress(Float taskProgress) { - this.taskProgress = taskProgress; - return this; - } - - public Builder initProgress(Float initProgress) { - this.initProgress = initProgress; - return this; - } - - public Builder currentPiece(Instant currentPiece) { - this.currentPiece = currentPiece; - return this; - } - - public Builder executionStartTime(Instant executionStartTime) { - this.executionStartTime = executionStartTime; - return this; - } - - public Builder executionEndTime(Instant executionEndTime) { - this.executionEndTime = executionEndTime; - return this; - } - - public Builder isLatest(Boolean isLatest) { - this.isLatest = isLatest; - return this; - } - - public Builder taskType(String taskType) { - this.taskType = taskType; - return this; - } - - public Builder checkpointId(String checkpointId) { - this.checkpointId = checkpointId; - return this; - } - public Builder detector(AnomalyDetector detector) { this.detector = detector; return this; } - public Builder coordinatingNode(String coordinatingNode) { - this.coordinatingNode = coordinatingNode; - return this; - } - - public Builder workerNode(String workerNode) { - this.workerNode = workerNode; - return this; - } - public Builder detectionDateRange(DateRange detectionDateRange) { this.detectionDateRange = detectionDateRange; return this; } - public Builder entity(Entity entity) { - this.entity = entity; - return this; - } - - public Builder parentTaskId(String parentTaskId) { - this.parentTaskId = parentTaskId; - return this; - } - - public Builder 
estimatedMinutesLeft(Integer estimatedMinutesLeft) { - this.estimatedMinutesLeft = estimatedMinutesLeft; - return this; - } - - public Builder user(User user) { - this.user = user; - return this; - } - public ADTask build() { ADTask adTask = new ADTask(); adTask.taskId = this.taskId; adTask.lastUpdateTime = this.lastUpdateTime; adTask.error = this.error; adTask.state = this.state; - adTask.detectorId = this.detectorId; + adTask.configId = this.configId; adTask.taskProgress = this.taskProgress; adTask.initProgress = this.initProgress; adTask.currentPiece = this.currentPiece; @@ -381,56 +201,9 @@ public ADTask build() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { XContentBuilder xContentBuilder = builder.startObject(); - if (taskId != null) { - xContentBuilder.field(TASK_ID_FIELD, taskId); - } - if (lastUpdateTime != null) { - xContentBuilder.field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); - } - if (startedBy != null) { - xContentBuilder.field(STARTED_BY_FIELD, startedBy); - } - if (stoppedBy != null) { - xContentBuilder.field(STOPPED_BY_FIELD, stoppedBy); - } - if (error != null) { - xContentBuilder.field(ERROR_FIELD, error); - } - if (state != null) { - xContentBuilder.field(STATE_FIELD, state); - } - if (detectorId != null) { - xContentBuilder.field(DETECTOR_ID_FIELD, detectorId); - } - if (taskProgress != null) { - xContentBuilder.field(TASK_PROGRESS_FIELD, taskProgress); - } - if (initProgress != null) { - xContentBuilder.field(INIT_PROGRESS_FIELD, initProgress); - } - if (currentPiece != null) { - xContentBuilder.field(CURRENT_PIECE_FIELD, currentPiece.toEpochMilli()); - } - if (executionStartTime != null) { - xContentBuilder.field(EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); - } - if (executionEndTime != null) { - xContentBuilder.field(EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); - } - if (isLatest != null) { - xContentBuilder.field(IS_LATEST_FIELD, isLatest); - } - if (taskType != null) { - xContentBuilder.field(TASK_TYPE_FIELD, taskType); - } - if (checkpointId != null) { - xContentBuilder.field(CHECKPOINT_ID_FIELD, checkpointId); - } - if (coordinatingNode != null) { - xContentBuilder.field(COORDINATING_NODE_FIELD, coordinatingNode); - } - if (workerNode != null) { - xContentBuilder.field(WORKER_NODE_FIELD, workerNode); + xContentBuilder = super.toXContent(xContentBuilder, params); + if (configId != null) { + xContentBuilder.field(DETECTOR_ID_FIELD, configId); } if (detector != null) { xContentBuilder.field(DETECTOR_FIELD, detector); @@ -438,18 +211,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (detectionDateRange != null) { xContentBuilder.field(DETECTION_DATE_RANGE_FIELD, detectionDateRange); } - if (entity != null) { - xContentBuilder.field(ENTITY_FIELD, entity); - } - if (parentTaskId != null) { - xContentBuilder.field(PARENT_TASK_ID_FIELD, parentTaskId); - } - if (estimatedMinutesLeft != null) { - xContentBuilder.field(ESTIMATED_MINUTES_LEFT_FIELD, estimatedMinutesLeft); - } - if (user != null) { - xContentBuilder.field(USER_FIELD, user); - } return xContentBuilder.endObject(); } @@ -488,73 +249,73 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept parser.nextToken(); switch (fieldName) { - case LAST_UPDATE_TIME_FIELD: + case TimeSeriesTask.LAST_UPDATE_TIME_FIELD: lastUpdateTime = ParseUtils.toInstant(parser); break; - case STARTED_BY_FIELD: + case TimeSeriesTask.STARTED_BY_FIELD: startedBy = 
parser.text(); break; - case STOPPED_BY_FIELD: + case TimeSeriesTask.STOPPED_BY_FIELD: stoppedBy = parser.text(); break; - case ERROR_FIELD: + case TimeSeriesTask.ERROR_FIELD: error = parser.text(); break; - case STATE_FIELD: + case TimeSeriesTask.STATE_FIELD: state = parser.text(); break; case DETECTOR_ID_FIELD: detectorId = parser.text(); break; - case TASK_PROGRESS_FIELD: + case TimeSeriesTask.TASK_PROGRESS_FIELD: taskProgress = parser.floatValue(); break; - case INIT_PROGRESS_FIELD: + case TimeSeriesTask.INIT_PROGRESS_FIELD: initProgress = parser.floatValue(); break; - case CURRENT_PIECE_FIELD: + case TimeSeriesTask.CURRENT_PIECE_FIELD: currentPiece = ParseUtils.toInstant(parser); break; - case EXECUTION_START_TIME_FIELD: + case TimeSeriesTask.EXECUTION_START_TIME_FIELD: executionStartTime = ParseUtils.toInstant(parser); break; - case EXECUTION_END_TIME_FIELD: + case TimeSeriesTask.EXECUTION_END_TIME_FIELD: executionEndTime = ParseUtils.toInstant(parser); break; - case IS_LATEST_FIELD: + case TimeSeriesTask.IS_LATEST_FIELD: isLatest = parser.booleanValue(); break; - case TASK_TYPE_FIELD: + case TimeSeriesTask.TASK_TYPE_FIELD: taskType = parser.text(); break; - case CHECKPOINT_ID_FIELD: + case TimeSeriesTask.CHECKPOINT_ID_FIELD: checkpointId = parser.text(); break; case DETECTOR_FIELD: detector = AnomalyDetector.parse(parser); break; - case TASK_ID_FIELD: + case TimeSeriesTask.TASK_ID_FIELD: parsedTaskId = parser.text(); break; - case COORDINATING_NODE_FIELD: + case TimeSeriesTask.COORDINATING_NODE_FIELD: coordinatingNode = parser.text(); break; - case WORKER_NODE_FIELD: + case TimeSeriesTask.WORKER_NODE_FIELD: workerNode = parser.text(); break; case DETECTION_DATE_RANGE_FIELD: detectionDateRange = DateRange.parse(parser); break; - case ENTITY_FIELD: + case TimeSeriesTask.ENTITY_FIELD: entity = Entity.parse(parser); break; - case PARENT_TASK_ID_FIELD: + case TimeSeriesTask.PARENT_TASK_ID_FIELD: parentTaskId = parser.text(); break; - case ESTIMATED_MINUTES_LEFT_FIELD: + case TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD: estimatedMinutesLeft = parser.intValue(); break; - case USER_FIELD: + case TimeSeriesTask.USER_FIELD: user = User.parse(parser); break; default: @@ -613,185 +374,37 @@ public static ADTask parse(XContentParser parser, String taskId) throws IOExcept @Generated @Override - public boolean equals(Object o) { - if (this == o) + public boolean equals(Object other) { + if (this == other) return true; - if (o == null || getClass() != o.getClass()) + if (other == null || getClass() != other.getClass()) return false; - ADTask that = (ADTask) o; - return Objects.equal(getTaskId(), that.getTaskId()) - && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) - && Objects.equal(getStartedBy(), that.getStartedBy()) - && Objects.equal(getStoppedBy(), that.getStoppedBy()) - && Objects.equal(getError(), that.getError()) - && Objects.equal(getState(), that.getState()) - && Objects.equal(getId(), that.getId()) - && Objects.equal(getTaskProgress(), that.getTaskProgress()) - && Objects.equal(getInitProgress(), that.getInitProgress()) - && Objects.equal(getCurrentPiece(), that.getCurrentPiece()) - && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) - && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) - && Objects.equal(getLatest(), that.getLatest()) - && Objects.equal(getTaskType(), that.getTaskType()) - && Objects.equal(getCheckpointId(), that.getCheckpointId()) - && Objects.equal(getCoordinatingNode(), that.getCoordinatingNode()) - && 
Objects.equal(getWorkerNode(), that.getWorkerNode()) + ADTask that = (ADTask) other; + return super.equals(that) + && Objects.equal(getConfigId(), that.getConfigId()) && Objects.equal(getDetector(), that.getDetector()) - && Objects.equal(getDetectionDateRange(), that.getDetectionDateRange()) - && Objects.equal(getEntity(), that.getEntity()) - && Objects.equal(getParentTaskId(), that.getParentTaskId()) - && Objects.equal(getEstimatedMinutesLeft(), that.getEstimatedMinutesLeft()) - && Objects.equal(getUser(), that.getUser()); + && Objects.equal(getDetectionDateRange(), that.getDetectionDateRange()); } @Generated @Override public int hashCode() { - return Objects - .hashCode( - taskId, - lastUpdateTime, - startedBy, - stoppedBy, - error, - state, - detectorId, - taskProgress, - initProgress, - currentPiece, - executionStartTime, - executionEndTime, - isLatest, - taskType, - checkpointId, - coordinatingNode, - workerNode, - detector, - detectionDateRange, - entity, - parentTaskId, - estimatedMinutesLeft, - user - ); - } - - public String getTaskId() { - return taskId; - } - - public void setTaskId(String taskId) { - this.taskId = taskId; - } - - public Instant getLastUpdateTime() { - return lastUpdateTime; - } - - public String getStartedBy() { - return startedBy; - } - - public String getStoppedBy() { - return stoppedBy; - } - - public String getError() { - return error; - } - - public void setError(String error) { - this.error = error; - } - - public String getState() { - return state; - } - - public void setState(String state) { - this.state = state; - } - - public String getId() { - return detectorId; - } - - public Float getTaskProgress() { - return taskProgress; - } - - public Float getInitProgress() { - return initProgress; - } - - public Instant getCurrentPiece() { - return currentPiece; - } - - public Instant getExecutionStartTime() { - return executionStartTime; - } - - public Instant getExecutionEndTime() { - return executionEndTime; - } - - public Boolean getLatest() { - return isLatest; - } - - public String getTaskType() { - return taskType; - } - - public String getCheckpointId() { - return checkpointId; + int superHashCode = super.hashCode(); + int hash = Objects.hashCode(configId, detector, detectionDateRange); + hash += 89 * superHashCode; + return hash; } public AnomalyDetector getDetector() { return detector; } - public String getCoordinatingNode() { - return coordinatingNode; - } - - public String getWorkerNode() { - return workerNode; - } - public DateRange getDetectionDateRange() { return detectionDateRange; } - public Entity getEntity() { - return entity; - } - + @Override public String getEntityModelId() { - return entity == null ? null : entity.getModelId(getId()).orElse(null); - } - - public String getParentTaskId() { - return parentTaskId; - } - - public Integer getEstimatedMinutesLeft() { - return estimatedMinutesLeft; - } - - public User getUser() { - return user; - } - - public void setDetectionDateRange(DateRange detectionDateRange) { - this.detectionDateRange = detectionDateRange; - } - - public void setLatest(Boolean latest) { - isLatest = latest; - } - - public void setLastUpdateTime(Instant lastUpdateTime) { - this.lastUpdateTime = lastUpdateTime; + return entity == null ? 
null : entity.getModelId(getConfigId()).orElse(null);
    }
}
diff --git a/src/main/java/org/opensearch/ad/model/ADTaskType.java b/src/main/java/org/opensearch/ad/model/ADTaskType.java
index b4e06aefc..fd5f97b37 100644
--- a/src/main/java/org/opensearch/ad/model/ADTaskType.java
+++ b/src/main/java/org/opensearch/ad/model/ADTaskType.java
@@ -12,37 +12,39 @@
 package org.opensearch.ad.model;
 
 import java.util.List;
-import java.util.stream.Collectors;
+
+import org.opensearch.timeseries.model.TaskType;
 
 import com.google.common.collect.ImmutableList;
 
-public enum ADTaskType {
+public enum ADTaskType implements TaskType {
     @Deprecated
     HISTORICAL,
-    REALTIME_SINGLE_ENTITY,
-    REALTIME_HC_DETECTOR,
-    HISTORICAL_SINGLE_ENTITY,
+    AD_REALTIME_SINGLE_STREAM,
+    AD_REALTIME_HC_DETECTOR,
+    AD_HISTORICAL_SINGLE_STREAM,
     // detector level task to track overall state, init progress, error etc. for HC detector
-    HISTORICAL_HC_DETECTOR,
+    AD_HISTORICAL_HC_DETECTOR,
     // entity level task to track just one specific entity's state, init progress, error etc.
-    HISTORICAL_HC_ENTITY;
+    AD_HISTORICAL_HC_ENTITY;
 
     public static List<ADTaskType> HISTORICAL_DETECTOR_TASK_TYPES = ImmutableList
-        .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL);
+        .of(ADTaskType.AD_HISTORICAL_HC_DETECTOR, ADTaskType.AD_HISTORICAL_SINGLE_STREAM, ADTaskType.HISTORICAL);
     public static List<ADTaskType> ALL_HISTORICAL_TASK_TYPES = ImmutableList
-        .of(ADTaskType.HISTORICAL_HC_DETECTOR, ADTaskType.HISTORICAL_SINGLE_ENTITY, ADTaskType.HISTORICAL_HC_ENTITY, ADTaskType.HISTORICAL);
+        .of(
+            ADTaskType.AD_HISTORICAL_HC_DETECTOR,
+            ADTaskType.AD_HISTORICAL_SINGLE_STREAM,
+            ADTaskType.AD_HISTORICAL_HC_ENTITY,
+            ADTaskType.HISTORICAL
+        );
     public static List<ADTaskType> REALTIME_TASK_TYPES = ImmutableList
-        .of(ADTaskType.REALTIME_SINGLE_ENTITY, ADTaskType.REALTIME_HC_DETECTOR);
+        .of(ADTaskType.AD_REALTIME_SINGLE_STREAM, ADTaskType.AD_REALTIME_HC_DETECTOR);
     public static List<ADTaskType> ALL_DETECTOR_TASK_TYPES = ImmutableList
         .of(
-            ADTaskType.REALTIME_SINGLE_ENTITY,
-            ADTaskType.REALTIME_HC_DETECTOR,
-            ADTaskType.HISTORICAL_SINGLE_ENTITY,
-            ADTaskType.HISTORICAL_HC_DETECTOR,
+            ADTaskType.AD_REALTIME_SINGLE_STREAM,
+            ADTaskType.AD_REALTIME_HC_DETECTOR,
+            ADTaskType.AD_HISTORICAL_SINGLE_STREAM,
+            ADTaskType.AD_HISTORICAL_HC_DETECTOR,
             ADTaskType.HISTORICAL
         );
-
-    public static List<String> taskTypeToString(List<ADTaskType> adTaskTypes) {
-        return adTaskTypes.stream().map(type -> type.name()).collect(Collectors.toList());
-    }
 }
diff --git a/src/main/java/org/opensearch/ad/model/DetectorProfile.java b/src/main/java/org/opensearch/ad/model/DetectorProfile.java
index 77418552e..5ec2c5e51 100644
--- a/src/main/java/org/opensearch/ad/model/DetectorProfile.java
+++ b/src/main/java/org/opensearch/ad/model/DetectorProfile.java
@@ -23,9 +23,10 @@
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.timeseries.model.ConfigState;
 
 public class DetectorProfile implements Writeable, ToXContentObject, Mergeable {
-    private DetectorState state;
+    private ConfigState state;
     private String error;
     private ModelProfileOnNode[] modelProfile;
     private int shingleSize;
@@ -43,7 +44,7 @@ public XContentBuilder toXContent(XContentBuilder builder) throws IOException {
 
     public DetectorProfile(StreamInput in) throws IOException {
         if (in.readBoolean()) {
-            this.state = in.readEnum(DetectorState.class);
+            this.state = in.readEnum(ConfigState.class);
         }
 
         this.error = in.readOptionalString();
@@
-65,7 +66,7 @@ public DetectorProfile(StreamInput in) throws IOException { private DetectorProfile() {} public static class Builder { - private DetectorState state = null; + private ConfigState state = null; private String error = null; private ModelProfileOnNode[] modelProfile = null; private int shingleSize = -1; @@ -79,7 +80,7 @@ public static class Builder { public Builder() {} - public Builder state(DetectorState state) { + public Builder state(ConfigState state) { this.state = state; return this; } @@ -227,11 +228,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return xContentBuilder.endObject(); } - public DetectorState getState() { + public ConfigState getState() { return state; } - public void setState(DetectorState state) { + public void setState(ConfigState state) { this.state = state; } diff --git a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java index 1e45bcc7a..2202bf215 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java +++ b/src/main/java/org/opensearch/ad/model/ModelProfileOnNode.java @@ -22,6 +22,7 @@ import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ModelProfile; public class ModelProfileOnNode implements Writeable, ToXContent { // field name in toXContent diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java similarity index 66% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java index 049b2d587..0199231a3 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointMaintainWorker.java @@ -16,34 +16,38 @@ import java.time.Clock; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; import java.util.Optional; import java.util.Random; +import java.util.function.Function; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; -public class CheckpointMaintainWorker extends ScheduledWorker { - private static final Logger LOG = LogManager.getLogger(CheckpointMaintainWorker.class); - public static final String WORKER_NAME = "checkpoint-maintain"; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; - private CheckPointMaintainRequestAdapter adapter; +public 
class ADCheckpointMaintainWorker extends + CheckpointMaintainWorker { + public static final String WORKER_NAME = "ad-checkpoint-maintain"; - public CheckpointMaintainWorker( + public ADCheckpointMaintainWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -51,10 +55,10 @@ public CheckpointMaintainWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointWriteWorker checkpointWriteQueue, + ADCheckpointWriteWorker checkpointWriteQueue, Duration stateTtl, NodeStateManager nodeStateManager, - CheckPointMaintainRequestAdapter adapter + Function> converter ) { super( WORKER_NAME, @@ -73,7 +77,9 @@ public CheckpointMaintainWorker( maintenanceFreqConstant, checkpointWriteQueue, stateTtl, - nodeStateManager + nodeStateManager, + converter, + AnalysisType.AD ); this.batchSize = AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); @@ -87,18 +93,5 @@ public CheckpointMaintainWorker( AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it ); - this.adapter = adapter; - } - - @Override - protected List transformRequests(List requests) { - List allRequests = new ArrayList<>(); - for (CheckpointMaintainRequest request : requests) { - Optional converted = adapter.convert(request); - if (!converted.isEmpty()) { - allRequests.add(converted.get()); - } - } - return allRequests; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java new file mode 100644 index 000000000..907d04641 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointReadWorker.java @@ -0,0 +1,169 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: + * a). If a checkpoint is not found, we forward that request to the cold start queue. + * b). When a request gets errors, the queue does not change its expiry time and puts + * that request to the end of the queue and automatically retries them before they expire. + * c) When a checkpoint is found, we load that point to memory and score the input + * data point and save the result if a complete model exists. Otherwise, we enqueue + * the sample. If we can host that model in memory (e.g., there is enough memory), + * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the + * updated checkpoint back to disk. 
+ * + */ +public class ADCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "ad-checkpoint-read"; + + public ADCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADModelManager modelManager, + ADCheckpointDao checkpointDao, + ADColdStartWorker entityColdStartQueue, + ADResultWriteWorker resultWriteQueue, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ADCheckpointWriteWorker checkpointWriteQueue, + Stats adStats + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + resultWriteQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + adStats, + AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ADCommonName.CHECKPOINT_INDEX_NAME, + StatNames.AD_MODEL_CORRUTPION_COUNT, + AnalysisType.AD + ); + } + + @Override + protected void saveResult( + ThresholdingResult result, + Config config, + FeatureRequest origRequest, + Optional entity, + String modelId + ) { + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(origRequest.getCurrentFeature(), config), + entity, + indexUtil.getSchemaVersion(ADIndex.RESULT), + modelId, + null, + null + ); + + for (AnomalyResult r : indexableResults) { + resultWriteWorker + .put( + new ADResultWriteRequest( + origRequest.getExpirationEpochMs(), + config.getId(), + result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, + r, + config.getCustomResultIndex() + ) + ); + } + ; + } + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java new file mode 100644 index 000000000..c121ba9bf --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADCheckpointWriteWorker.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ADCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "ad-checkpoint-write"; + + public ADCheckpointWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager adNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + adNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.AD + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java similarity index 61% rename from src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java rename to src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java index fb834e089..2cf271ac7 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ColdEntityWorker.java +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdEntityWorker.java @@ -12,21 +12,32 @@ package org.opensearch.ad.ratelimit; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; import java.time.Clock; import java.time.Duration; -import java.util.List; import java.util.Random; -import java.util.stream.Collectors; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; 
+import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ThresholdingResult; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; /** * A queue slowly releasing low-priority requests to CheckpointReadQueue @@ -43,16 +54,17 @@ * entity requests.  * */ -public class ColdEntityWorker extends ScheduledWorker { - public static final String WORKER_NAME = "cold-entity"; +public class ADColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "ad-cold-entity"; - public ColdEntityWorker( + public ADColdEntityWorker( long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -60,7 +72,7 @@ public ColdEntityWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, - CheckpointReadWorker checkpointReadQueue, + ADCheckpointReadWorker checkpointReadQueue, Duration stateTtl, NodeStateManager nodeStateManager ) { @@ -81,25 +93,10 @@ public ColdEntityWorker( maintenanceFreqConstant, checkpointReadQueue, stateTtl, - nodeStateManager + nodeStateManager, + AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.AD ); - - this.batchSize = AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, it -> this.batchSize = it); - - this.expectedExecutionTimeInMilliSecsPerRequest = AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS - .get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer( - EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, - it -> this.expectedExecutionTimeInMilliSecsPerRequest = it - ); - } - - @Override - protected List transformRequests(List requests) { - // guarantee we only send low priority requests - return requests.stream().filter(request -> request.priority == RequestPriority.LOW).collect(Collectors.toList()); } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java new file mode 100644 index 000000000..1c28744a8 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADColdStartWorker.java @@ -0,0 +1,115 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Random; + +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A queue for HCAD model training (a.k.a. cold start). As model training is a + * pretty expensive operation, we pull cold start requests from the queue in a + * serial fashion. Each detector has an equal chance of being pulled. The equal + * probability is achieved by putting model training requests for different + * detectors into different segments and pulling requests from segments in a + * round-robin fashion. + * + */ + +// suppress warning due to the use of generic type ADModelState +public class ADColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "ad-cold-start"; + + public ADColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADEntityColdStart entityColdStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ADPriorityCache cacheProvider + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + entityColdStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.AD + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest request, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + request.getEntity(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java new file mode 100644 index 000000000..912396ebd --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require 
contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ADResultWriteRequest extends ResultWriteRequest { + + public ADResultWriteRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + super(expirationEpochMs, detectorId, priority, result, resultIndex); + } + + public ADResultWriteRequest(StreamInput in) throws IOException { + super(in, AnomalyResult::new); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java new file mode 100644 index 000000000..1f4106bb5 --- /dev/null +++ b/src/main/java/org/opensearch/ad/ratelimit/ADResultWriteWorker.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.ratelimit; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ADResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "ad-result-write"; + + public ADResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ADIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + 
clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + AD_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + AnomalyResult::parse, + AnalysisType.AD + ); + } + + @Override + protected ADResultBulkRequest toBatchRequest(List toProcess) { + final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); + for (ADResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ADResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + AnomalyResult result, + String resultIndex + ) { + return new ADResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java b/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java deleted file mode 100644 index 53d05ff11..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityColdStartWorker.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY; - -import java.time.Clock; -import java.time.Duration; -import java.util.ArrayDeque; -import java.util.Locale; -import java.util.Optional; -import java.util.Random; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.ActionListener; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Setting; -import org.opensearch.common.settings.Settings; -import org.opensearch.threadpool.ThreadPool; - -/** - * A queue for HCAD model training (a.k.a. cold start). As model training is a - * pretty expensive operation, we pull cold start requests from the queue in a - * serial fashion. Each detector has an equal chance of being pulled. The equal - * probability is achieved by putting model training requests for different - * detectors into different segments and pulling requests from segments in a - * round-robin fashion. 
- * - */ -public class EntityColdStartWorker extends SingleRequestWorker { - private static final Logger LOG = LogManager.getLogger(EntityColdStartWorker.class); - public static final String WORKER_NAME = "cold-start"; - - private final EntityColdStarter entityColdStarter; - private final CacheProvider cacheProvider; - - public EntityColdStartWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, - Setting maxHeapPercentForQueueSetting, - ClusterService clusterService, - Random random, - ADCircuitBreakerService adCircuitBreakerService, - ThreadPool threadPool, - Settings settings, - float maxQueuedTaskRatio, - Clock clock, - float mediumSegmentPruneRatio, - float lowSegmentPruneRatio, - int maintenanceFreqConstant, - Duration executionTtl, - EntityColdStarter entityColdStarter, - Duration stateTtl, - NodeStateManager nodeStateManager, - CacheProvider cacheProvider - ) { - super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, - maxHeapPercentForQueueSetting, - clusterService, - random, - adCircuitBreakerService, - threadPool, - settings, - maxQueuedTaskRatio, - clock, - mediumSegmentPruneRatio, - lowSegmentPruneRatio, - maintenanceFreqConstant, - ENTITY_COLD_START_QUEUE_CONCURRENCY, - executionTtl, - stateTtl, - nodeStateManager - ); - this.entityColdStarter = entityColdStarter; - this.cacheProvider = cacheProvider; - } - - @Override - protected void executeRequest(EntityRequest coldStartRequest, ActionListener listener) { - String detectorId = coldStartRequest.getId(); - - Optional modelId = coldStartRequest.getModelId(); - - if (false == modelId.isPresent()) { - String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); - LOG.warn(error); - listener.onFailure(new RuntimeException(error)); - return; - } - - ModelState modelState = new ModelState<>( - new EntityModel(coldStartRequest.getEntity(), new ArrayDeque<>(), null), - modelId.get(), - detectorId, - ModelType.ENTITY.getName(), - clock, - 0 - ); - - ActionListener coldStartListener = ActionListener.wrap(r -> { - nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { - try { - if (!detectorOptional.isPresent()) { - LOG - .error( - new ParameterizedMessage( - "fail to load trained model [{}] to cache due to the detector not being found.", - modelState.getModelId() - ) - ); - return; - } - AnomalyDetector detector = detectorOptional.get(); - EntityModel model = modelState.getModel(); - // load to cache if cold start succeeds - if (model != null && model.getTrcf() != null) { - cacheProvider.get().hostIfPossible(detector, modelState); - } - } finally { - listener.onResponse(null); - } - }, listener::onFailure)); - - }, e -> { - try { - if (ExceptionUtil.isOverloaded(e)) { - LOG.error("OpenSearch is overloaded"); - setCoolDownStart(); - } - nodeStateManager.setException(detectorId, e); - } finally { - listener.onFailure(e); - } - }); - - entityColdStarter.trainModel(coldStartRequest.getEntity(), detectorId, modelState, coldStartListener); - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java deleted file mode 100644 index 875974dbb..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityFeatureRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source 
license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import org.opensearch.timeseries.model.Entity; - -public class EntityFeatureRequest extends EntityRequest { - private final double[] currentFeature; - private final long dataStartTimeMillis; - - public EntityFeatureRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - Entity entity, - double[] currentFeature, - long dataStartTimeMs - ) { - super(expirationEpochMs, detectorId, priority, entity); - this.currentFeature = currentFeature; - this.dataStartTimeMillis = dataStartTimeMs; - } - - public double[] getCurrentFeature() { - return currentFeature; - } - - public long getDataStartTimeMillis() { - return dataStartTimeMillis; - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java b/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java deleted file mode 100644 index 7acf2652a..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/EntityRequest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import java.util.Optional; - -import org.opensearch.timeseries.model.Entity; - -public class EntityRequest extends QueuedRequest { - private final Entity entity; - - /** - * - * @param expirationEpochMs Expiry time of the request - * @param detectorId Detector Id - * @param priority the entity's priority - * @param entity the entity's attributes - */ - public EntityRequest(long expirationEpochMs, String detectorId, RequestPriority priority, Entity entity) { - super(expirationEpochMs, detectorId, priority); - this.entity = entity; - } - - public Entity getEntity() { - return entity; - } - - public Optional getModelId() { - return entity.getModelId(detectorId); - } -} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java deleted file mode 100644 index a25bf3924..000000000 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteRequest.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.ratelimit; - -import java.io.IOException; - -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; - -public class ResultWriteRequest extends QueuedRequest implements Writeable { - private final AnomalyResult result; - // If resultIndex is null, result will be stored in default result index. 
- private final String resultIndex; - - public ResultWriteRequest( - long expirationEpochMs, - String detectorId, - RequestPriority priority, - AnomalyResult result, - String resultIndex - ) { - super(expirationEpochMs, detectorId, priority); - this.result = result; - this.resultIndex = resultIndex; - } - - public ResultWriteRequest(StreamInput in) throws IOException { - this.result = new AnomalyResult(in); - this.resultIndex = in.readOptionalString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - result.writeTo(out); - out.writeOptionalString(resultIndex); - } - - public AnomalyResult getResult() { - return result; - } - - public String getCustomResultIndex() { - return resultIndex; - } -} diff --git a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java index 331c3151f..4a10b3ad9 100644 --- a/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/AbstractAnomalyDetectorAction.java @@ -11,13 +11,14 @@ package org.opensearch.ad.rest; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_INTERVAL; import static org.opensearch.ad.settings.AnomalyDetectorSettings.DETECTION_WINDOW_DELAY; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -31,25 +32,25 @@ public abstract class AbstractAnomalyDetectorAction extends BaseRestHandler { protected volatile Integer maxSingleEntityDetectors; protected volatile Integer maxMultiEntityDetectors; protected volatile Integer maxAnomalyFeatures; + protected volatile Integer maxCategoricalFields; public AbstractAnomalyDetectorAction(Settings settings, ClusterService clusterService) { - this.requestTimeout = REQUEST_TIMEOUT.get(settings); + this.requestTimeout = AD_REQUEST_TIMEOUT.get(settings); this.detectionInterval = DETECTION_INTERVAL.get(settings); this.detectionWindowDelay = DETECTION_WINDOW_DELAY.get(settings); - this.maxSingleEntityDetectors = MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); - this.maxMultiEntityDetectors = MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxSingleEntityDetectors = AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings); + this.maxMultiEntityDetectors = AD_MAX_HC_ANOMALY_DETECTORS.get(settings); this.maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); + this.maxCategoricalFields = ADNumericSetting.maxCategoricalFields(); // TODO: will add more cluster setting consumer later // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings - clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + 
clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> requestTimeout = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_INTERVAL, it -> detectionInterval = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(DETECTION_WINDOW_DELAY, it -> detectionWindowDelay = it); clusterService .getClusterSettings() - .addSettingsUpdateConsumer(MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, it -> maxSingleEntityDetectors = it); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(MAX_MULTI_ENTITY_ANOMALY_DETECTORS, it -> maxMultiEntityDetectors = it); + .addSettingsUpdateConsumer(AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, it -> maxSingleEntityDetectors = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_HC_ANOMALY_DETECTORS, it -> maxMultiEntityDetectors = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); } } diff --git a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java index a5052c84d..14ef4c652 100644 --- a/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestAnomalyDetectorJobAction.java @@ -11,11 +11,8 @@ package org.opensearch.ad.rest; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; -import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; @@ -26,32 +23,30 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.seqno.SequenceNumbers; -import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; import com.google.common.collect.ImmutableList; /** * This class consists of the REST handler to handle request to start/stop AD job. 
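Several handlers in this change repeat the same dynamic-setting idiom seen just above with AD_REQUEST_TIMEOUT: read the initial value from Settings, hold it in a volatile field, and register a consumer so later cluster-setting updates take effect without a restart. A minimal sketch of the idiom; the setting key plugins.example.request_timeout is hypothetical, and the consumer registration assumes the setting has been registered with the plugin's ClusterSettings.

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;

class TimeoutAwareHandler {
    // Hypothetical dynamic node-scope setting; the real handlers above use
    // AnomalyDetectorSettings.AD_REQUEST_TIMEOUT instead.
    static final Setting<TimeValue> EXAMPLE_TIMEOUT = Setting
        .positiveTimeSetting(
            "plugins.example.request_timeout",
            TimeValue.timeValueSeconds(10),
            Setting.Property.NodeScope,
            Setting.Property.Dynamic
        );

    private volatile TimeValue requestTimeout;

    TimeoutAwareHandler(Settings settings, ClusterService clusterService) {
        // initial value from node settings ...
        this.requestTimeout = EXAMPLE_TIMEOUT.get(settings);
        // ... then track cluster-setting updates; this throws if the setting
        // was never registered with the plugin's ClusterSettings.
        clusterService.getClusterSettings().addSettingsUpdateConsumer(EXAMPLE_TIMEOUT, it -> requestTimeout = it);
    }

    TimeValue requestTimeout() {
        return requestTimeout;
    }
}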
*/ -public class RestAnomalyDetectorJobAction extends BaseRestHandler { +public class RestAnomalyDetectorJobAction extends RestJobAction { public static final String AD_JOB_ACTION = "anomaly_detector_job_action"; private volatile TimeValue requestTimeout; public RestAnomalyDetectorJobAction(Settings settings, ClusterService clusterService) { - this.requestTimeout = REQUEST_TIMEOUT.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + this.requestTimeout = AD_REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> requestTimeout = it); } @Override @@ -66,40 +61,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } String detectorId = request.param(DETECTOR_ID); - long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); - long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); boolean historical = request.paramAsBoolean("historical", false); String rawPath = request.rawPath(); - DateRange detectionDateRange = parseDetectionDateRange(request); + DateRange detectionDateRange = parseInputDateRange(request); - AnomalyDetectorJobRequest anomalyDetectorJobRequest = new AnomalyDetectorJobRequest( - detectorId, - detectionDateRange, - historical, - seqNo, - primaryTerm, - rawPath - ); + JobRequest anomalyDetectorJobRequest = new JobRequest(detectorId, detectionDateRange, historical, rawPath); return channel -> client .execute(AnomalyDetectorJobAction.INSTANCE, anomalyDetectorJobRequest, new RestToXContentListener<>(channel)); } - private DateRange parseDetectionDateRange(RestRequest request) throws IOException { - if (!request.hasContent()) { - return null; - } - XContentParser parser = request.contentParser(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - DateRange dateRange = DateRange.parse(parser); - return dateRange; - } - - @Override - public List routes() { - return ImmutableList.of(); - } - @Override public List replacedRoutes() { return ImmutableList diff --git a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java index b7a3aae6c..fc1859888 100644 --- a/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestDeleteAnomalyDetectorAction.java @@ -20,7 +20,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.rest.handler.AnomalyDetectorActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorRequest; @@ -29,6 +28,7 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.rest.handler.ConfigUpdateConfirmer; import com.google.common.collect.ImmutableList; @@ -40,7 +40,7 @@ public class RestDeleteAnomalyDetectorAction extends BaseRestHandler { public static final String DELETE_ANOMALY_DETECTOR_ACTION = "delete_anomaly_detector"; private static final Logger logger = LogManager.getLogger(RestDeleteAnomalyDetectorAction.class); - private final AnomalyDetectorActionHandler handler = new 
AnomalyDetectorActionHandler(); + private final ConfigUpdateConfirmer handler = new ConfigUpdateConfirmer(); public RestDeleteAnomalyDetectorAction() {} diff --git a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java index fe0d10ec9..13bdb5009 100644 --- a/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestExecuteAnomalyDetectorAction.java @@ -11,7 +11,7 @@ package org.opensearch.ad.rest; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.timeseries.util.RestHandlerUtils.DETECTOR_ID; import static org.opensearch.timeseries.util.RestHandlerUtils.RUN; @@ -54,8 +54,8 @@ public class RestExecuteAnomalyDetectorAction extends BaseRestHandler { private final Logger logger = LogManager.getLogger(RestExecuteAnomalyDetectorAction.class); public RestExecuteAnomalyDetectorAction(Settings settings, ClusterService clusterService) { - this.requestTimeout = REQUEST_TIMEOUT.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(REQUEST_TIMEOUT, it -> requestTimeout = it); + this.requestTimeout = AD_REQUEST_TIMEOUT.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_REQUEST_TIMEOUT, it -> requestTimeout = it); } @Override diff --git a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java index d14ff85ce..0991364c2 100644 --- a/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestGetAnomalyDetectorAction.java @@ -18,24 +18,20 @@ import java.io.IOException; import java.util.List; import java.util.Locale; -import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.transport.GetAnomalyDetectorAction; -import org.opensearch.ad.transport.GetAnomalyDetectorRequest; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestActions; import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -66,7 +62,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli boolean returnJob = request.paramAsBoolean("job", false); boolean returnTask = request.paramAsBoolean("task", false); boolean all = request.paramAsBoolean("_all", false); - GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest( detectorId, RestActions.parseVersion(request), returnJob, @@ -74,7 +70,7 @@ protected RestChannelConsumer 
prepareRequest(RestRequest request, NodeClient cli typesStr, rawPath, all, - buildEntity(request, detectorId) + RestHandlerUtils.buildEntity(request, detectorId) ); return channel -> client @@ -138,35 +134,4 @@ public List replacedRoutes() { ) ); } - - private Entity buildEntity(RestRequest request, String detectorId) throws IOException { - if (Strings.isEmpty(detectorId)) { - throw new IllegalStateException(ADCommonMessages.AD_ID_MISSING_MSG); - } - - String entityName = request.param(ADCommonName.CATEGORICAL_FIELD); - String entityValue = request.param(CommonName.ENTITY_KEY); - - if (entityName != null && entityValue != null) { - // single-stream profile request: - // GET _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= - return Entity.createSingleAttributeEntity(entityName, entityValue); - } else if (request.hasContent()) { - /* HCAD profile request: - * GET _plugins/_anomaly_detection/detectors//_profile/init_progress - * { - * "entity": [{ - * "name": "clientip", - * "value": "13.24.0.0" - * }] - * } - */ - Optional entity = Entity.fromJsonObject(request.contentParser()); - if (entity.isPresent()) { - return entity.get(); - } - } - // not a valid profile request with correct entity information - return null; - } } diff --git a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java index 6231d8e11..66981d54c 100644 --- a/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestIndexAnomalyDetectorAction.java @@ -94,7 +94,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli requestTimeout, maxSingleEntityDetectors, maxMultiEntityDetectors, - maxAnomalyFeatures + maxAnomalyFeatures, + maxCategoricalFields ); return channel -> client diff --git a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java index 65b936e98..a53e36931 100644 --- a/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestStatsAnomalyDetectorAction.java @@ -22,38 +22,36 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.transport.ADStatsRequest; import org.opensearch.ad.transport.StatsAnomalyDetectorAction; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.common.Strings; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.stats.Stats; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.ImmutableList; /** - * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from the anomaly detector plugin. + * RestStatsAnomalyDetectorAction consists of the REST handler to get the stats from the time series analytics plugin. 
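With ADStats replaced by the shared Stats registry, the stats handler that follows derives the set of valid stat names from timeSeriesStats.getStats().keySet() and rejects anything else in the request. Below is a small standalone model of that membership check; the stat names used are illustrative, not an exhaustive list of the plugin's stats.

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Minimal model of the stat-name validation in getRequest below: every
// requested name must come from the registry's valid set, otherwise the
// request should be rejected with the unrecognized names.
final class StatNameCheck {
    static Set<String> invalidNames(String csv, Set<String> validStats) {
        Set<String> requested = new HashSet<>(Arrays.asList(csv.split(",")));
        requested.removeAll(validStats);
        return requested; // empty means the request is acceptable
    }

    public static void main(String[] args) {
        Set<String> valid = Set.of("ad_execute_request_count", "models_checkpoint_failure_count");
        System.out.println(invalidNames("ad_execute_request_count,bogus_stat", valid)); // [bogus_stat]
    }
}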
*/ public class RestStatsAnomalyDetectorAction extends BaseRestHandler { private static final String STATS_ANOMALY_DETECTOR_ACTION = "stats_anomaly_detector"; - private ADStats adStats; - private ClusterService clusterService; + private Stats timeSeriesStats; private DiscoveryNodeFilterer nodeFilter; /** * Constructor * - * @param adStats ADStats object + * @param timeSeriesStats TimeSeriesStats object * @param nodeFilter util class to get eligible data nodes */ - public RestStatsAnomalyDetectorAction(ADStats adStats, DiscoveryNodeFilterer nodeFilter) { - this.adStats = adStats; + public RestStatsAnomalyDetectorAction(Stats timeSeriesStats, DiscoveryNodeFilterer nodeFilter) { + this.timeSeriesStats = timeSeriesStats; this.nodeFilter = nodeFilter; } @@ -80,7 +78,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli private ADStatsRequest getRequest(RestRequest request) { // parse the nodes the user wants to query the stats for String nodesIdsStr = request.param("nodeId"); - Set validStats = adStats.getStats().keySet(); + Set validStats = timeSeriesStats.getStats().keySet(); ADStatsRequest adStatsRequest = null; if (!Strings.isEmpty(nodesIdsStr)) { diff --git a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java index e728889f8..b764f1b62 100644 --- a/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/rest/RestValidateAnomalyDetectorAction.java @@ -22,7 +22,6 @@ import java.util.List; import java.util.Locale; import java.util.Set; -import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; import org.opensearch.ad.constant.ADCommonMessages; @@ -44,7 +43,7 @@ import org.opensearch.rest.action.RestToXContentListener; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; import com.google.common.collect.ImmutableList; @@ -54,12 +53,6 @@ public class RestValidateAnomalyDetectorAction extends AbstractAnomalyDetectorAction { private static final String VALIDATE_ANOMALY_DETECTOR_ACTION = "validate_anomaly_detector_action"; - public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays - .asList(ValidationAspect.values()) - .stream() - .map(aspect -> aspect.getName()) - .collect(Collectors.toSet()); - public RestValidateAnomalyDetectorAction(Settings settings, ClusterService clusterService) { super(settings, clusterService); } @@ -98,7 +91,7 @@ protected void sendAnomalyDetectorValidationParseResponse(DetectorValidationIssu private Boolean validationTypesAreAccepted(String validationType) { Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); - return (!Collections.disjoint(typesInRequest, ALL_VALIDATION_ASPECTS_STRS)); + return (!Collections.disjoint(typesInRequest, AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS)); } @Override @@ -141,7 +134,8 @@ protected BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request maxSingleEntityDetectors, maxMultiEntityDetectors, maxAnomalyFeatures, - requestTimeout + requestTimeout, + maxCategoricalFields ); client.execute(ValidateAnomalyDetectorAction.INSTANCE, validateAnomalyDetectorRequest, new RestToXContentListener<>(channel)); }; diff --git 
a/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java new file mode 100644 index 000000000..4f199b0a7 --- /dev/null +++ b/src/main/java/org/opensearch/ad/rest/handler/ADIndexJobActionHandler.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.rest.handler; + +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.ad.ExecuteADResultResponseRecorder; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.transport.AnomalyResultAction; +import org.opensearch.ad.transport.AnomalyResultRequest; +import org.opensearch.ad.transport.StopDetectorAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ADIndexJobActionHandler extends + IndexJobActionHandler { + + public ADIndexJobActionHandler( + Client client, + ADIndexManagement indexManagement, + NamedXContentRegistry xContentRegistry, + ADTaskManager adTaskManager, + ExecuteADResultResponseRecorder recorder, + NodeStateManager nodeStateManager, + Settings settings + ) { + super( + client, + indexManagement, + xContentRegistry, + adTaskManager, + recorder, + AnomalyResultAction.INSTANCE, + AnalysisType.AD, + DETECTION_STATE_INDEX, + StopDetectorAction.INSTANCE, + nodeStateManager, + settings, + AD_REQUEST_TIMEOUT + ); + } + + @Override + protected ResultRequest createResultRequest(String configID, long start, long end) { + return new AnomalyResultRequest(configID, start, end); + } + + @Override + protected List getHistorialConfigTaskTypes() { + return HISTORICAL_DETECTOR_TASK_TYPES; + } +} diff --git a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java index 82f07b497..76773fbbc 100644 --- a/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java +++ b/src/main/java/org/opensearch/ad/rest/handler/AbstractAnomalyDetectorActionHandler.java @@ -22,70 +22,47 @@ import java.io.IOException; import java.time.Clock; import java.time.Instant; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; -import org.apache.commons.lang.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import 
org.opensearch.action.ActionResponse; -import org.opensearch.action.admin.indices.create.CreateIndexResponse; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; -import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; -import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.WriteRequest; -import org.opensearch.action.support.replication.ReplicationResponse; -import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.MergeableList; -import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; -import org.opensearch.ad.settings.ADNumericSetting; -import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; -import org.opensearch.ad.util.MultiResponsesDelegateActionListener; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.authuser.User; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.query.BoolQueryBuilder; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; import org.opensearch.rest.RestRequest; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.common.exception.ValidationException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; import com.google.common.collect.Sets; @@ -115,47 +92,21 @@ * instantiate the ModelValidationActionHandler class and run the non-blocker validation logic
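ADIndexJobActionHandler above and the reworked AbstractAnomalyDetectorActionHandler below follow the same template-method shape: the timeseries base class owns the generic workflow and asks AD subclasses only for the analysis-specific pieces, such as the result-request type and the error-message text. A self-contained sketch of the pattern; ConfigHandlerBase and AdHandler are illustrative names, not the real hierarchy.

// Base class owns the workflow; subclasses supply the variable parts
// (cf. createResultRequest and the error-message overrides in this diff).
abstract class ConfigHandlerBase<R> {
    abstract R createResultRequest(String configId, long start, long end);

    abstract String exceedMaxConfigsMessage(int max);

    // Shared workflow: enforce the limit, then build the analysis-specific request.
    final R startConfig(String configId, long start, long end, int existingConfigs, int maxConfigs) {
        if (existingConfigs >= maxConfigs) {
            throw new IllegalArgumentException(exceedMaxConfigsMessage(maxConfigs));
        }
        return createResultRequest(configId, start, end);
    }
}

final class AdHandler extends ConfigHandlerBase<String> {
    @Override
    String createResultRequest(String configId, long start, long end) {
        // stand-in for new AnomalyResultRequest(configId, start, end)
        return "AnomalyResultRequest(" + configId + ", " + start + ", " + end + ")";
    }

    @Override
    String exceedMaxConfigsMessage(int max) {
        return "Can't create more than " + max + " HC anomaly detectors.";
    }
}

The diff applies the same idea at a larger scale: index checks, limits, and validation move into the shared base, while the AD handler keeps only its request types and message templates.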

* */ -public abstract class AbstractAnomalyDetectorActionHandler { - public static final String EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG = "Can't create more than %d multi-entity anomaly detectors."; - public static final String EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG = - "Can't create more than %d single-entity anomaly detectors."; +public abstract class AbstractAnomalyDetectorActionHandler extends + AbstractTimeSeriesActionHandler { + protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); + + public static final String EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG = "Can't create more than %d HC anomaly detectors."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG = + "Can't create more than %d single-stream anomaly detectors."; public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create anomaly detector as no document is found in the indices: "; - public static final String ONLY_ONE_CATEGORICAL_FIELD_ERR_MSG = "We can have only one categorical field."; - public static final String CATEGORICAL_FIELD_TYPE_ERR_MSG = "A categorical field must be of type keyword or ip."; - public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; public static final String DUPLICATE_DETECTOR_MSG = "Cannot create anomaly detector with name [%s] as it's already used by detector %s"; - public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; - public static final Integer MAX_DETECTOR_NAME_SIZE = 64; - private static final Set DEFAULT_VALIDATION_ASPECTS = Sets.newHashSet(ValidationAspect.DETECTOR); - - public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + MAX_DETECTOR_NAME_SIZE + " characters."; - - protected final ADIndexManagement anomalyDetectionIndices; - protected final String detectorId; - protected final Long seqNo; - protected final Long primaryTerm; - protected final WriteRequest.RefreshPolicy refreshPolicy; - protected final AnomalyDetector anomalyDetector; - protected final ClusterService clusterService; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of detector %s"; - protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class); - protected final TimeValue requestTimeout; - protected final Integer maxSingleEntityAnomalyDetectors; - protected final Integer maxMultiEntityAnomalyDetectors; - protected final Integer maxAnomalyFeatures; - protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler(); - protected final RestRequest.Method method; - protected final Client client; - protected final SecurityClientUtil clientUtil; - protected final TransportService transportService; - protected final NamedXContentRegistry xContentRegistry; - protected final ActionListener listener; - protected final User user; - protected final ADTaskManager adTaskManager; - protected final SearchFeatureDao searchFeatureDao; - protected final boolean isDryRun; + protected final Integer maxSingleStreamDetectors; + protected final Integer maxHCAnomalyDetectors; + protected final TaskManager adTaskManager; protected final Clock clock; - protected final String validationType; protected final Settings settings; /** @@ -165,7 +116,6 @@ public abstract class AbstractAnomalyDetectorActionHandler listener, ADIndexManagement anomalyDetectionIndices, String detectorId, Long seqNo, @@ -199,748 +149,164 @@ public AbstractAnomalyDetectorActionHandler( 
WriteRequest.RefreshPolicy refreshPolicy, AnomalyDetector anomalyDetector, TimeValue requestTimeout, - Integer maxSingleEntityAnomalyDetectors, - Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures, + Integer maxSingleStreamAnomalyDetectors, + Integer maxHCAnomalyDetectors, + Integer maxFeatures, + Integer maxCategoricalFields, RestRequest.Method method, NamedXContentRegistry xContentRegistry, User user, - ADTaskManager adTaskManager, + TaskManager adTaskManager, SearchFeatureDao searchFeatureDao, String validationType, boolean isDryRun, Clock clock, Settings settings ) { - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; - this.transportService = transportService; - this.anomalyDetectionIndices = anomalyDetectionIndices; - this.listener = listener; - this.detectorId = detectorId; - this.seqNo = seqNo; - this.primaryTerm = primaryTerm; - this.refreshPolicy = refreshPolicy; - this.anomalyDetector = anomalyDetector; - this.requestTimeout = requestTimeout; - this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; - this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; - this.maxAnomalyFeatures = maxAnomalyFeatures; - this.method = method; - this.xContentRegistry = xContentRegistry; - this.user = user; + super( + anomalyDetector, + anomalyDetectionIndices, + isDryRun, + client, + detectorId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.AD + ); + this.maxSingleStreamDetectors = maxSingleStreamAnomalyDetectors; + this.maxHCAnomalyDetectors = maxHCAnomalyDetectors; this.adTaskManager = adTaskManager; - this.searchFeatureDao = searchFeatureDao; - this.validationType = validationType; - this.isDryRun = isDryRun; this.clock = clock; this.settings = settings; } - /** - * Start function to process create/update/validate anomaly detector request. - * If detector is not using custom result index, check if anomaly detector - * index exist first, if not, will create first. Otherwise, check if custom - * result index exists or not. If exists, will check if index mapping matches - * AD result index mapping and if user has correct permission to write index. - * If doesn't exist, will create custom result index with AD result index - * mapping. 
- */ - public void start() { - String resultIndex = anomalyDetector.getCustomResultIndex(); - // use default detector result index which is system index - if (resultIndex == null) { - createOrUpdateDetector(); - return; - } - - if (this.isDryRun) { - if (anomalyDetectionIndices.doesIndexExist(resultIndex)) { - anomalyDetectionIndices - .validateCustomResultIndexAndExecute( - resultIndex, - () -> createOrUpdateDetector(), - ActionListener.wrap(r -> createOrUpdateDetector(), ex -> { - logger.error(ex); - listener - .onFailure( - new ValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX, ValidationAspect.DETECTOR) - ); - return; - }) - ); - return; - } else { - createOrUpdateDetector(); - return; - } - } - // use custom result index if not validating and resultIndex not null - anomalyDetectionIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateDetector(), listener); - } - - // if isDryRun is true then this method is being executed through Validation API meaning actual - // index won't be created, only validation checks will be executed throughout the class - private void createOrUpdateDetector() { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - if (!anomalyDetectionIndices.doesConfigIndexExist() && !this.isDryRun) { - logger.info("AnomalyDetector Indices do not exist"); - anomalyDetectionIndices - .initConfigIndex( - ActionListener - .wrap(response -> onCreateMappingsResponse(response, false), exception -> listener.onFailure(exception)) - ); - } else { - logger.info("AnomalyDetector Indices do exist, calling prepareAnomalyDetectorIndexing"); - logger.info("DryRun variable " + this.isDryRun); - validateDetectorName(this.isDryRun); - } - } catch (Exception e) { - logger.error("Failed to create or update detector " + detectorId, e); - listener.onFailure(e); - } + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.DETECTOR); } - // These validation checks are executed here and not in AnomalyDetector.parse() - // in order to not break any past detectors that were made with invalid names - // because it was never check on the backend in the past - protected void validateDetectorName(boolean indexingDryRun) { - if (!anomalyDetector.getName().matches(NAME_REGEX)) { - listener.onFailure(new ValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - - } - if (anomalyDetector.getName().length() > MAX_DETECTOR_NAME_SIZE) { - listener.onFailure(new ValidationException(INVALID_NAME_SIZE, ValidationIssueType.NAME, ValidationAspect.DETECTOR)); - return; - } - validateTimeField(indexingDryRun); + @Override + protected AnomalyDetector parse(XContentParser parser, GetResponse response) throws IOException { + return AnomalyDetector.parse(parser, response.getId(), response.getVersion()); } - protected void validateTimeField(boolean indexingDryRun) { - String givenTimeField = anomalyDetector.getTimeField(); - GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); - getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(givenTimeField); - getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); - - // comments explaining fieldMappingResponse parsing can be found inside following method: - // AbstractAnomalyDetectorActionHandler.validateCategoricalField(String, boolean) - ActionListener 
mappingsListener = ActionListener.wrap(getMappingsResponse -> {
-            boolean foundField = false;
-            Map<String, Map<String, GetFieldMappingsResponse.FieldMappingMetadata>> mappingsByIndex = getMappingsResponse.mappings();
-
-            for (Map<String, GetFieldMappingsResponse.FieldMappingMetadata> mappingsByField : mappingsByIndex.values()) {
-                for (Map.Entry<String, GetFieldMappingsResponse.FieldMappingMetadata> field2Metadata : mappingsByField.entrySet()) {
-
-                    GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue();
-                    if (fieldMetadata != null) {
-                        // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field
-                        Map<String, Object> fieldMap = fieldMetadata.sourceAsMap();
-                        if (fieldMap != null) {
-                            for (Object type : fieldMap.values()) {
-                                if (type instanceof Map) {
-                                    foundField = true;
-                                    Map<String, Object> metadataMap = (Map<String, Object>) type;
-                                    String typeName = (String) metadataMap.get(CommonName.TYPE);
-                                    if (!typeName.equals(CommonName.DATE_TYPE)) {
-                                        listener
-                                            .onFailure(
-                                                new ValidationException(
-                                                    String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField),
-                                                    ValidationIssueType.TIMEFIELD_FIELD,
-                                                    ValidationAspect.DETECTOR
-                                                )
-                                            );
-                                        return;
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            if (!foundField) {
-                listener
-                    .onFailure(
-                        new ValidationException(
-                            String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField),
-                            ValidationIssueType.TIMEFIELD_FIELD,
-                            ValidationAspect.DETECTOR
-                        )
-                    );
-                return;
-            }
-            prepareAnomalyDetectorIndexing(indexingDryRun);
-        }, error -> {
-            String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices());
-            logger.error(message, error);
-            listener.onFailure(new IllegalArgumentException(message));
-        });
-        clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener);
-    }
-
-    /**
-     * Prepare for indexing a new anomaly detector.
-     * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false
-     */
-    protected void prepareAnomalyDetectorIndexing(boolean indexingDryRun) {
-        if (method == RestRequest.Method.PUT) {
-            handler
-                .getDetectorJob(
-                    clusterService,
-                    client,
-                    detectorId,
-                    listener,
-                    () -> updateAnomalyDetector(detectorId, indexingDryRun),
-                    xContentRegistry
-                );
-        } else {
-            createAnomalyDetector(indexingDryRun);
-        }
-    }
-
-    protected void updateAnomalyDetector(String detectorId, boolean indexingDryRun) {
-        GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, detectorId);
-        client
-            .get(
-                request,
-                ActionListener
-                    .wrap(
-                        response -> onGetAnomalyDetectorResponse(response, indexingDryRun, detectorId),
-                        exception -> listener.onFailure(exception)
-                    )
-            );
-    }
-
-    private void onGetAnomalyDetectorResponse(GetResponse response, boolean indexingDryRun, String detectorId) {
-        if (!response.isExists()) {
-            listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND));
-            return;
-        }
-        try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
-            ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-            AnomalyDetector existingDetector = AnomalyDetector.parse(parser, response.getId(), response.getVersion());
-            // If detector category field changed, frontend may not be able to render AD result for different detector types correctly.
-            // For example, if detector changed from HC to single entity detector, AD result page may show multiple anomaly
-            // result points on the same time point if there are multiple entities have anomaly results.
-            // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type
-            // in top N entities list. That's confusing.
-            // So we decide to block updating detector category field.
-            if (!listEqualsWithoutConsideringOrder(existingDetector.getCategoryFields(), anomalyDetector.getCategoryFields())) {
-                listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST));
-                return;
-            }
-            if (!Objects.equals(existingDetector.getCustomResultIndex(), anomalyDetector.getCustomResultIndex())) {
-                listener
-                    .onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST));
-                return;
-            }
-
-            adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, (adTask) -> {
-                if (adTask.isPresent() && !adTask.get().isDone()) {
-                    // can't update detector if there is AD task running
-                    listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR));
-                } else {
-                    validateExistingDetector(existingDetector, indexingDryRun);
-                }
-            }, transportService, true, listener);
-        } catch (IOException e) {
-            String message = "Failed to parse anomaly detector " + detectorId;
-            logger.error(message, e);
-            listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-        }
-
-    }
-
-    protected void validateExistingDetector(AnomalyDetector existingDetector, boolean indexingDryRun) {
-        if (!hasCategoryField(existingDetector) && hasCategoryField(this.anomalyDetector)) {
-            validateAgainstExistingMultiEntityAnomalyDetector(detectorId, indexingDryRun);
-        } else {
-            validateCategoricalField(detectorId, indexingDryRun);
-        }
-    }
-
-    protected boolean hasCategoryField(AnomalyDetector detector) {
-        return detector.getCategoryFields() != null && !detector.getCategoryFields().isEmpty();
-    }
-
-    protected void validateAgainstExistingMultiEntityAnomalyDetector(String detectorId, boolean indexingDryRun) {
-        if (anomalyDetectionIndices.doesConfigIndexExist()) {
-            QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(AnomalyDetector.CATEGORY_FIELD));
-
-            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);
-
-            SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder);
-            client
-                .search(
-                    searchRequest,
-                    ActionListener
-                        .wrap(
-                            response -> onSearchMultiEntityAdResponse(response, detectorId, indexingDryRun),
-                            exception -> listener.onFailure(exception)
-                        )
-                );
-        } else {
-            validateCategoricalField(detectorId, indexingDryRun);
-        }
-
-    }
-
-    protected void createAnomalyDetector(boolean indexingDryRun) {
-        try {
-            List<String> categoricalFields = anomalyDetector.getCategoryFields();
-            if (categoricalFields != null && categoricalFields.size() > 0) {
-                validateAgainstExistingMultiEntityAnomalyDetector(null, indexingDryRun);
+    @Override
+    protected void confirmHistoricalRunning(String id, ActionListener<Void> listener) {
+        adTaskManager.getAndExecuteOnLatestConfigLevelTask(id, HISTORICAL_DETECTOR_TASK_TYPES, (adTask) -> {
+            if (adTask.isPresent() && !adTask.get().isDone()) {
+                // can't update detector if there is AD task running
+                listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR));
             } else {
-                if (anomalyDetectionIndices.doesConfigIndexExist()) {
-                    QueryBuilder query = QueryBuilders.matchAllQuery();
-                    SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout);
-
-                    SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder);
-
-                    client
-                        .search(
-                            searchRequest,
-                            ActionListener
-                                .wrap(
-                                    response -> onSearchSingleEntityAdResponse(response, indexingDryRun),
-                                    exception -> listener.onFailure(exception)
-                                )
-                        );
-                } else {
-                    searchAdInputIndices(null, indexingDryRun);
-                }
-
+                listener.onResponse(null);
             }
-        } catch (Exception e) {
-            listener.onFailure(e);
-        }
+        }, transportService, true, listener);
     }
 
-    protected void onSearchSingleEntityAdResponse(SearchResponse response, boolean indexingDryRun) throws IOException {
-        if (response.getHits().getTotalHits().value >= maxSingleEntityAnomalyDetectors) {
-            String errorMsgSingleEntity = String
-                .format(Locale.ROOT, EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors);
-            logger.error(errorMsgSingleEntity);
-            if (indexingDryRun) {
-                listener
-                    .onFailure(
-                        new ValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR)
-                    );
-                return;
-            }
-            listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity));
-        } else {
-            searchAdInputIndices(null, indexingDryRun);
-        }
+    @Override
+    protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) {
+        return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, getMaxSingleStreamConfigs());
     }
 
-    protected void onSearchMultiEntityAdResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException {
-        if (response.getHits().getTotalHits().value >= maxMultiEntityAnomalyDetectors) {
-            String errorMsg = String.format(Locale.ROOT, EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors);
-            logger.error(errorMsg);
-            if (indexingDryRun) {
-                listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS, ValidationAspect.DETECTOR));
-                return;
-            }
-            listener.onFailure(new IllegalArgumentException(errorMsg));
-        } else {
-            validateCategoricalField(detectorId, indexingDryRun);
-        }
+    @Override
+    protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) {
+        return String.format(Locale.ROOT, EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, getMaxHCConfigs());
     }
 
-    @SuppressWarnings("unchecked")
-    protected void validateCategoricalField(String detectorId, boolean indexingDryRun) {
-        List<String> categoryField = anomalyDetector.getCategoryFields();
-
-        if (categoryField == null) {
-            searchAdInputIndices(detectorId, indexingDryRun);
-            return;
-        }
-
-        // we only support a certain number of categorical field
-        // If there is more fields than required, AnomalyDetector's constructor
-        // throws ADValidationException before reaching this line
-        int maxCategoryFields = ADNumericSetting.maxCategoricalFields();
-        if (categoryField.size() > maxCategoryFields) {
-            listener
-                .onFailure(
-                    new ValidationException(
-                        CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields),
-                        ValidationIssueType.CATEGORY,
-                        ValidationAspect.DETECTOR
-                    )
-                );
-            return;
-        }
-
-        String categoryField0 = categoryField.get(0);
-
-        GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest();
-        getMappingsRequest.indices(anomalyDetector.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0]));
-        getMappingsRequest.indicesOptions(IndicesOptions.strictExpand());
-
-        ActionListener<GetFieldMappingsResponse> mappingsListener = ActionListener.wrap(getMappingsResponse -> {
-            // example getMappingsResponse:
-            // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service',
-            // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}}
-            // for nested field, it would be
-            // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2',
-            // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}}
-            boolean foundField = false;
-
-            // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata
-            Map<String, Map<String, GetFieldMappingsResponse.FieldMappingMetadata>> mappingsByIndex = getMappingsResponse.mappings();
-
-            for (Map<String, GetFieldMappingsResponse.FieldMappingMetadata> mappingsByField : mappingsByIndex.values()) {
-                for (Map.Entry<String, GetFieldMappingsResponse.FieldMappingMetadata> field2Metadata : mappingsByField.entrySet()) {
-                    // example output:
-                    // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2',
-                    // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}
-
-                    // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata
-
-                    GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue();
-
-                    if (fieldMetadata != null) {
-                        // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field
-                        Map<String, Object> fieldMap = fieldMetadata.sourceAsMap();
-                        if (fieldMap != null) {
-                            for (Object type : fieldMap.values()) {
-                                if (type != null && type instanceof Map) {
-                                    foundField = true;
-                                    Map<String, Object> metadataMap = (Map<String, Object>) type;
-                                    String typeName = (String) metadataMap.get(CommonName.TYPE);
-                                    if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) {
-                                        listener
-                                            .onFailure(
-                                                new ValidationException(
-                                                    CATEGORICAL_FIELD_TYPE_ERR_MSG,
-                                                    ValidationIssueType.CATEGORY,
-                                                    ValidationAspect.DETECTOR
-                                                )
-                                            );
-                                        return;
-                                    }
-                                }
-                            }
-                        }
-
-                    }
-                }
-            }
-
-            if (foundField == false) {
-                listener
-                    .onFailure(
-                        new ValidationException(
-                            String.format(Locale.ROOT, CATEGORY_NOT_FOUND_ERR_MSG, categoryField0),
-                            ValidationIssueType.CATEGORY,
-                            ValidationAspect.DETECTOR
-                        )
-                    );
-                return;
-            }
-
-            searchAdInputIndices(detectorId, indexingDryRun);
-        }, error -> {
-            String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", anomalyDetector.getIndices());
-            logger.error(message, error);
-            listener.onFailure(new IllegalArgumentException(message));
-        });
-
-        clientUtil.executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, mappingsListener);
+    @Override
+    protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) {
+        return String.format(Locale.ROOT, NO_DOCS_IN_USER_INDEX_MSG, suppliedIndices);
     }
 
-    protected void searchAdInputIndices(String detectorId, boolean indexingDryRun) {
-        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder()
-            .query(QueryBuilders.matchAllQuery())
-            .size(0)
-            .timeout(requestTimeout);
-
-        SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0])).source(searchSourceBuilder);
-
-        ActionListener<SearchResponse> searchResponseListener = ActionListener
-            .wrap(
-                searchResponse -> onSearchAdInputIndicesResponse(searchResponse, detectorId, indexingDryRun),
-                exception -> listener.onFailure(exception)
-            );
-
-        clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener);
+    @Override
+    protected String getDuplicateConfigErrorMsg(String name, List<String> otherConfigIds) {
+        return String.format(Locale.ROOT, DUPLICATE_DETECTOR_MSG, name, otherConfigIds);
     }
 
-    protected void onSearchAdInputIndicesResponse(SearchResponse response, String detectorId, boolean indexingDryRun) throws IOException {
-        if (response.getHits().getTotalHits().value == 0) {
-            String errorMsg = NO_DOCS_IN_USER_INDEX_MSG + Arrays.toString(anomalyDetector.getIndices().toArray(new String[0]));
-            logger.error(errorMsg);
-            if (indexingDryRun) {
-                listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.INDICES, ValidationAspect.DETECTOR));
-                return;
-            }
-            listener.onFailure(new IllegalArgumentException(errorMsg));
-        } else {
-            validateAnomalyDetectorFeatures(detectorId, indexingDryRun);
-        }
+    @Override
+    protected AnomalyDetector copyConfig(User user, Config config) {
+        return new AnomalyDetector(
+            config.getId(),
+            config.getVersion(),
+            config.getName(),
+            config.getDescription(),
+            config.getTimeField(),
+            config.getIndices(),
+            config.getFeatureAttributes(),
+            config.getFilterQuery(),
+            config.getInterval(),
+            config.getWindowDelay(),
+            config.getShingleSize(),
+            config.getUiMetadata(),
+            config.getSchemaVersion(),
+            Instant.now(),
+            config.getCategoryFields(),
+            user,
+            config.getCustomResultIndex(),
+            config.getImputationOption()
+        );
     }
 
-    protected void checkADNameExists(String detectorId, boolean indexingDryRun) throws IOException {
-        if (anomalyDetectionIndices.doesConfigIndexExist()) {
-            BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
-            // src/main/resources/mappings/anomaly-detectors.json#L14
-            boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", anomalyDetector.getName()));
-            if (StringUtils.isNotBlank(detectorId)) {
-                boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, detectorId));
-            }
-            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout);
-            SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder);
-            client
-                .search(
-                    searchRequest,
-                    ActionListener
-                        .wrap(
-                            searchResponse -> onSearchADNameResponse(searchResponse, detectorId, anomalyDetector.getName(), indexingDryRun),
-                            exception -> listener.onFailure(exception)
-                        )
-                );
-        } else {
-            tryIndexingAnomalyDetector(indexingDryRun);
-        }
-
+    @SuppressWarnings("unchecked")
+    @Override
+    protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) {
+        return (T) new IndexAnomalyDetectorResponse(
+            indexResponse.getId(),
+            indexResponse.getVersion(),
+            indexResponse.getSeqNo(),
+            indexResponse.getPrimaryTerm(),
+            (AnomalyDetector) config,
+            RestStatus.CREATED
+        );
     }
 
-    protected void onSearchADNameResponse(SearchResponse response, String detectorId, String name, boolean indexingDryRun)
-        throws IOException {
-        if (response.getHits().getTotalHits().value > 0) {
-            String errorMsg = String
-                .format(
-                    Locale.ROOT,
-                    DUPLICATE_DETECTOR_MSG,
-                    name,
-                    Arrays.stream(response.getHits().getHits()).map(hit -> hit.getId()).collect(Collectors.toList())
-                );
-            logger.warn(errorMsg);
-            listener.onFailure(new ValidationException(errorMsg, ValidationIssueType.NAME, ValidationAspect.DETECTOR));
-        } else {
-            tryIndexingAnomalyDetector(indexingDryRun);
-        }
+    @Override
+    protected Set<ValidationAspect> getDefaultValidationType() {
+        return Sets.newHashSet(ValidationAspect.DETECTOR);
     }
 
-    protected void tryIndexingAnomalyDetector(boolean indexingDryRun) throws IOException {
-        if (!indexingDryRun) {
-            indexAnomalyDetector(detectorId);
-        } else {
-            finishDetectorValidationOrContinueToModelValidation();
-        }
+    @Override
+    protected Integer getMaxSingleStreamConfigs() {
+        return maxSingleStreamDetectors;
     }
 
-    protected Set<ValidationAspect> getValidationTypes(String validationType) {
-        if (StringUtils.isBlank(validationType)) {
-            return DEFAULT_VALIDATION_ASPECTS;
-        } else {
-            Set<String> typesInRequest = new HashSet<>(Arrays.asList(validationType.split(",")));
-            return ValidationAspect
-                .getNames(Sets.intersection(RestValidateAnomalyDetectorAction.ALL_VALIDATION_ASPECTS_STRS, typesInRequest));
-        }
+    @Override
+    protected Integer getMaxHCConfigs() {
+        return maxHCAnomalyDetectors;
     }
 
-    protected void finishDetectorValidationOrContinueToModelValidation() {
-        logger.info("Skipping indexing detector. No blocking issue found so far.");
-        if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) {
-            listener.onResponse(null);
-        } else {
-            ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler(
-                clusterService,
-                client,
-                clientUtil,
-                (ActionListener<ValidateAnomalyDetectorResponse>) listener,
-                anomalyDetector,
-                requestTimeout,
-                xContentRegistry,
-                searchFeatureDao,
-                validationType,
-                clock,
-                settings,
-                user
-            );
-            modelValidationActionHandler.checkIfMultiEntityDetector();
-        }
+    @Override
+    protected String getFeatureErrorMsg(String name) {
+        return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name);
     }
 
-    @SuppressWarnings("unchecked")
-    protected void indexAnomalyDetector(String detectorId) throws IOException {
-        AnomalyDetector detector = new AnomalyDetector(
-            anomalyDetector.getId(),
-            anomalyDetector.getVersion(),
-            anomalyDetector.getName(),
-            anomalyDetector.getDescription(),
-            anomalyDetector.getTimeField(),
-            anomalyDetector.getIndices(),
-            anomalyDetector.getFeatureAttributes(),
-            anomalyDetector.getFilterQuery(),
-            anomalyDetector.getInterval(),
-            anomalyDetector.getWindowDelay(),
-            anomalyDetector.getShingleSize(),
-            anomalyDetector.getUiMetadata(),
-            anomalyDetector.getSchemaVersion(),
-            Instant.now(),
-            anomalyDetector.getCategoryFields(),
-            user,
-            anomalyDetector.getCustomResultIndex(),
-            anomalyDetector.getImputationOption()
+    @Override
+    protected void validateModel(ActionListener<T> listener) {
+        ModelValidationActionHandler modelValidationActionHandler = new ModelValidationActionHandler(
+            clusterService,
+            client,
+            clientUtil,
+            (ActionListener<ValidateAnomalyDetectorResponse>) listener,
+            (AnomalyDetector) config,
+            requestTimeout,
+            xContentRegistry,
+            searchFeatureDao,
+            validationType,
+            clock,
+            settings,
+            user
         );
-        IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX)
-            .setRefreshPolicy(refreshPolicy)
-            .source(detector.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE))
-            .setIfSeqNo(seqNo)
-            .setIfPrimaryTerm(primaryTerm)
-            .timeout(requestTimeout);
-        if (StringUtils.isNotBlank(detectorId)) {
-            indexRequest.id(detectorId);
-        }
-
-        client.index(indexRequest, new ActionListener<IndexResponse>() {
-            @Override
-            public void onResponse(IndexResponse indexResponse) {
-                String errorMsg = checkShardsFailure(indexResponse);
-                if (errorMsg != null) {
-                    listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status()));
-                    return;
-                }
-                listener
-                    .onResponse(
-                        (T) new IndexAnomalyDetectorResponse(
-                            indexResponse.getId(),
-                            indexResponse.getVersion(),
-                            indexResponse.getSeqNo(),
-                            indexResponse.getPrimaryTerm(),
-                            detector,
-                            RestStatus.CREATED
-                        )
-                    );
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                logger.warn("Failed to update detector", e);
-                if (e.getMessage() != null && e.getMessage().contains("version conflict")) {
-                    listener
-                        .onFailure(
-                            new IllegalArgumentException("There was a problem updating the historical detector:[" + detectorId + "]")
-                        );
-                } else {
-                    listener.onFailure(e);
-                }
-            }
-        });
-    }
-
-    protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun) throws IOException {
-        if (response.isAcknowledged()) {
-            logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX);
-            prepareAnomalyDetectorIndexing(indexingDryRun);
-        } else {
-            logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX);
-            listener
-                .onFailure(
-                    new OpenSearchStatusException(
-                        "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.",
-                        RestStatus.INTERNAL_SERVER_ERROR
-                    )
-                );
-        }
-    }
-
-    protected String checkShardsFailure(IndexResponse response) {
-        StringBuilder failureReasons = new StringBuilder();
-        if (response.getShardInfo().getFailed() > 0) {
-            for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) {
-                failureReasons.append(failure);
-            }
-            return failureReasons.toString();
-        }
-        return null;
-    }
-
-    /**
-     * Validate config/syntax, and runtime error of detector features
-     * @param detectorId detector id
-     * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector
-     * @throws IOException when fail to parse feature aggregation
-     */
-    // TODO: move this method to util class so that it can be re-usable for more use cases
-    // https://github.com/opensearch-project/anomaly-detection/issues/39
-    protected void validateAnomalyDetectorFeatures(String detectorId, boolean indexingDryRun) throws IOException {
-        if (anomalyDetector != null
-            && (anomalyDetector.getFeatureAttributes() == null || anomalyDetector.getFeatureAttributes().isEmpty())) {
-            checkADNameExists(detectorId, indexingDryRun);
-            return;
-        }
-        // checking configuration/syntax error of detector features
-        String error = RestHandlerUtils.checkFeaturesSyntax(anomalyDetector, maxAnomalyFeatures);
-        if (StringUtils.isNotBlank(error)) {
-            if (indexingDryRun) {
-                listener.onFailure(new ValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES, ValidationAspect.DETECTOR));
-                return;
-            }
-            listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST));
-            return;
-        }
-        // checking runtime error from feature query
-        ActionListener<MergeableList<Optional<double[]>>> validateFeatureQueriesListener = ActionListener
-            .wrap(
-                response -> { checkADNameExists(detectorId, indexingDryRun); },
-                exception -> {
-                    listener
-                        .onFailure(
-                            new ValidationException(
-                                exception.getMessage(),
-                                ValidationIssueType.FEATURE_ATTRIBUTES,
-                                ValidationAspect.DETECTOR
-                            )
-                        );
-                }
-            );
-        MultiResponsesDelegateActionListener<MergeableList<Optional<double[]>>> multiFeatureQueriesResponseListener =
-            new MultiResponsesDelegateActionListener<MergeableList<Optional<double[]>>>(
-                validateFeatureQueriesListener,
-                anomalyDetector.getFeatureAttributes().size(),
-                String.format(Locale.ROOT, "Validation failed for feature(s) of detector %s", anomalyDetector.getName()),
-                false
-            );
-
-        for (Feature feature : anomalyDetector.getFeatureAttributes()) {
-            SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery());
-            AggregatorFactories.Builder internalAgg = parseAggregators(
-                feature.getAggregation().toString(),
-                xContentRegistry,
-                feature.getId()
-            );
-            ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next());
-            SearchRequest searchRequest = new SearchRequest().indices(anomalyDetector.getIndices().toArray(new String[0])).source(ssb);
-            ActionListener<SearchResponse> searchResponseListener = ActionListener.wrap(response -> {
-                Optional<double[]> aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId()));
-                if (aggFeatureResult.isPresent()) {
-                    multiFeatureQueriesResponseListener
-                        .onResponse(
-                            new MergeableList<Optional<double[]>>(new ArrayList<Optional<double[]>>(Arrays.asList(aggFeatureResult)))
-                        );
-                } else {
-                    String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName();
-                    logger.error(errorMessage);
-                    multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST));
-                }
-            }, e -> {
-                String errorMessage;
-                if (isExceptionCausedByInvalidQuery(e)) {
-                    errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName();
-                } else {
-                    errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName();
-                }
-                logger.error(errorMessage, e);
-                multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e));
-            });
-            clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, searchResponseListener);
-        }
+        modelValidationActionHandler.checkIfMultiEntityDetector();
     }
 }
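The hunk above is a template-method refactor: the concrete validation chain (single-stream/HC limit checks, duplicate-name search, feature dry runs) moves into a shared base class, and the AD handler now only supplies small hooks such as getMaxHCConfigs and getExceedMaxHCConfigsErrorMsg. A minimal, self-contained sketch of that split follows; the base class itself is not shown in this diff, so everything here except the hook names is illustrative.

    import java.util.Locale;

    // Shared flow lives in the abstract class; subclasses only fill in the
    // analysis-specific limits and error strings (the hooks the diff adds).
    abstract class ConfigValidator {
        final void validateLimit(long currentCount) {
            if (currentCount >= getMaxHCConfigs()) {
                throw new IllegalArgumentException(getExceedMaxHCConfigsErrorMsg(getMaxHCConfigs()));
            }
        }

        protected abstract Integer getMaxHCConfigs();

        protected abstract String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs);
    }

    class AdConfigValidator extends ConfigValidator {
        @Override
        protected Integer getMaxHCConfigs() {
            return 10; // stand-in for the maxHCAnomalyDetectors setting
        }

        @Override
        protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) {
            return String.format(Locale.ROOT, "Can't create more than %d HC detectors.", maxHCConfigs);
        }
    }

    public class TemplateMethodDemo {
        public static void main(String[] args) {
            ConfigValidator v = new AdConfigValidator();
            v.validateLimit(3); // under the limit, passes silently
            try {
                v.validateLimit(10); // trips the AD-specific message
            } catch (IllegalArgumentException e) {
                System.out.println(e.getMessage());
            }
        }
    }

The payoff is that a forecasting handler can reuse the same flow by overriding the same hooks, which is why the removed methods do not reappear anywhere in this file.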
diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java
index b401ce007..51e3df820 100644
--- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java
+++ b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorActionHandler.java
@@ -11,14 +11,14 @@
 
 package org.opensearch.ad.rest.handler;
 
-import org.opensearch.action.ActionListener;
 import org.opensearch.action.support.WriteRequest;
-import org.opensearch.ad.feature.SearchFeatureDao;
+import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.task.ADTaskManager;
+import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.transport.IndexAnomalyDetectorResponse;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
@@ -26,6 +26,9 @@
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.rest.RestRequest;
+import org.opensearch.timeseries.feature.SearchFeatureDao;
+import org.opensearch.timeseries.task.TaskManager;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 import org.opensearch.transport.TransportService;
 
 /**
@@ -50,9 +53,10 @@ public class IndexAnomalyDetectorActionHandler extends AbstractAnomalyDetectorAc
  * @param refreshPolicy refresh policy
  * @param anomalyDetector anomaly detector instance
  * @param requestTimeout request time out configuration
- * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed
- * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed
- * @param maxAnomalyFeatures max features allowed per detector
+ * @param maxSingleStreamDetectors max single-stream anomaly detectors allowed
+ * @param maxHCDetectors max HC detectors allowed
+ * @param maxFeatures max features allowed per detector
+ * @param maxCategoricalFields max number of categorical fields
  * @param method Rest Method type
  * @param xContentRegistry Registry which is used for XContentParser
  * @param user User context
@@ -65,7 +69,6 @@ public IndexAnomalyDetectorActionHandler(
         Client client,
         SecurityClientUtil clientUtil,
         TransportService transportService,
-        ActionListener<IndexAnomalyDetectorResponse> listener,
         ADIndexManagement anomalyDetectionIndices,
         String detectorId,
         Long seqNo,
@@ -73,13 +76,14 @@ public IndexAnomalyDetectorActionHandler(
         WriteRequest.RefreshPolicy refreshPolicy,
         AnomalyDetector anomalyDetector,
         TimeValue requestTimeout,
-        Integer maxSingleEntityAnomalyDetectors,
-        Integer maxMultiEntityAnomalyDetectors,
-        Integer maxAnomalyFeatures,
+        Integer maxSingleStreamDetectors,
+        Integer maxHCDetectors,
+        Integer maxFeatures,
+        Integer maxCategoricalFields,
         RestRequest.Method method,
         NamedXContentRegistry xContentRegistry,
         User user,
-        ADTaskManager adTaskManager,
+        TaskManager<ADTaskCacheManager, ADTaskType, ADTask> adTaskManager,
        SearchFeatureDao searchFeatureDao,
        Settings settings
    ) {
@@ -88,7 +92,6 @@ public IndexAnomalyDetectorActionHandler(
             client,
             clientUtil,
             transportService,
-            listener,
             anomalyDetectionIndices,
             detectorId,
             seqNo,
@@ -96,9 +99,10 @@ public IndexAnomalyDetectorActionHandler(
             refreshPolicy,
             anomalyDetector,
             requestTimeout,
-            maxSingleEntityAnomalyDetectors,
-            maxMultiEntityAnomalyDetectors,
-            maxAnomalyFeatures,
+            maxSingleStreamDetectors,
+            maxHCDetectors,
+            maxFeatures,
+            maxCategoricalFields,
             method,
             xContentRegistry,
             user,
@@ -110,12 +114,4 @@ public IndexAnomalyDetectorActionHandler(
             settings
         );
     }
-
-    /**
-     * Start function to process create/update anomaly detector request.
-     */
-    @Override
-    public void start() {
-        super.start();
-    }
 }
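Two things happen in this file: the response listener leaves the constructor (so a handler no longer captures per-request state at construction time), and the no-op start() override is deleted. The deletion is safe because a subclass that neither changes behavior nor visibility simply inherits the method, as this tiny runnable sketch (illustrative names only) shows:

    // Why deleting the no-op override is safe: a subclass that only calls
    // super inherits identical behavior once the override is removed.
    class Base {
        public void start() {
            System.out.println("shared start flow");
        }
    }

    class Sub extends Base {
        // Before: @Override public void start() { super.start(); } -- a pure
        // pass-through. After deleting it, callers still reach Base.start().
    }

    public class OverrideDemo {
        public static void main(String[] args) {
            new Sub().start(); // prints "shared start flow"
        }
    }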
diff --git a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java
deleted file mode 100644
index 824c6fc21..000000000
--- a/src/main/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandler.java
+++ /dev/null
@@ -1,434 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.rest.handler;
-
-import static org.opensearch.action.DocWriteResponse.Result.CREATED;
-import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
-import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure;
-import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
-import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry;
-
-import java.io.IOException;
-import java.time.Duration;
-import java.time.Instant;
-
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.OpenSearchStatusException;
-import org.opensearch.action.ActionListener;
-import org.opensearch.action.get.GetRequest;
-import org.opensearch.action.get.GetResponse;
-import org.opensearch.action.index.IndexRequest;
-import org.opensearch.action.index.IndexResponse;
-import org.opensearch.action.support.WriteRequest;
-import org.opensearch.ad.ExecuteADResultResponseRecorder;
-import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.model.ADTaskState;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
-import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.transport.AnomalyDetectorJobResponse;
-import org.opensearch.ad.transport.AnomalyResultAction;
-import org.opensearch.ad.transport.AnomalyResultRequest;
-import org.opensearch.ad.transport.StopDetectorAction;
-import org.opensearch.ad.transport.StopDetectorRequest;
-import org.opensearch.ad.transport.StopDetectorResponse;
-import org.opensearch.client.Client;
-import org.opensearch.common.unit.TimeValue;
-import org.opensearch.common.xcontent.XContentFactory;
-import org.opensearch.core.rest.RestStatus;
-import org.opensearch.core.xcontent.NamedXContentRegistry;
-import org.opensearch.core.xcontent.XContentParser;
-import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;
-import org.opensearch.jobscheduler.spi.schedule.Schedule;
-import org.opensearch.timeseries.constant.CommonName;
-import org.opensearch.timeseries.function.ExecutorFunction;
-import org.opensearch.timeseries.model.IntervalTimeConfiguration;
-import org.opensearch.timeseries.util.RestHandlerUtils;
-import org.opensearch.transport.TransportService;
-
-import com.google.common.base.Throwables;
-
-/**
- * Anomaly detector job REST action handler to process POST/PUT request.
- */
-public class IndexAnomalyDetectorJobActionHandler {
-
-    private final ADIndexManagement anomalyDetectionIndices;
-    private final String detectorId;
-    private final Long seqNo;
-    private final Long primaryTerm;
-    private final Client client;
-    private final NamedXContentRegistry xContentRegistry;
-    private final TransportService transportService;
-    private final ADTaskManager adTaskManager;
-
-    private final Logger logger = LogManager.getLogger(IndexAnomalyDetectorJobActionHandler.class);
-    private final TimeValue requestTimeout;
-    private final ExecuteADResultResponseRecorder recorder;
-
-    /**
-     * Constructor function.
-     *
-     * @param client ES node client that executes actions on the local node
-     * @param anomalyDetectionIndices anomaly detector index manager
-     * @param detectorId detector identifier
-     * @param seqNo sequence number of last modification
-     * @param primaryTerm primary term of last modification
-     * @param requestTimeout request time out configuration
-     * @param xContentRegistry Registry which is used for XContentParser
-     * @param transportService transport service
-     * @param adTaskManager AD task manager
-     * @param recorder Utility to record AnomalyResultAction execution result
-     */
-    public IndexAnomalyDetectorJobActionHandler(
-        Client client,
-        ADIndexManagement anomalyDetectionIndices,
-        String detectorId,
-        Long seqNo,
-        Long primaryTerm,
-        TimeValue requestTimeout,
-        NamedXContentRegistry xContentRegistry,
-        TransportService transportService,
-        ADTaskManager adTaskManager,
-        ExecuteADResultResponseRecorder recorder
-    ) {
-        this.client = client;
-        this.anomalyDetectionIndices = anomalyDetectionIndices;
-        this.detectorId = detectorId;
-        this.seqNo = seqNo;
-        this.primaryTerm = primaryTerm;
-        this.requestTimeout = requestTimeout;
-        this.xContentRegistry = xContentRegistry;
-        this.transportService = transportService;
-        this.adTaskManager = adTaskManager;
-        this.recorder = recorder;
-    }
-
-    /**
-     * Start anomaly detector job.
-     * 1. If job doesn't exist, create new job.
-     * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job.
-     * @param detector anomaly detector
-     * @param listener Listener to send responses
-     */
-    public void startAnomalyDetectorJob(AnomalyDetector detector, ActionListener<AnomalyDetectorJobResponse> listener) {
-        // this start listener is created & injected throughout the job handler so that whenever the job response is received,
-        // there's the extra step of trying to index results and update detector state with a 60s delay.
-        ActionListener<AnomalyDetectorJobResponse> startListener = ActionListener.wrap(r -> {
-            try {
-                Instant executionEndTime = Instant.now();
-                IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) detector.getInterval();
-                Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit());
-                AnomalyResultRequest getRequest = new AnomalyResultRequest(
-                    detector.getId(),
-                    executionStartTime.toEpochMilli(),
-                    executionEndTime.toEpochMilli()
-                );
-                client
-                    .execute(
-                        AnomalyResultAction.INSTANCE,
-                        getRequest,
-                        ActionListener
-                            .wrap(
-                                response -> recorder.indexAnomalyResult(executionStartTime, executionEndTime, response, detector),
-                                exception -> {
-
-                                    recorder
-                                        .indexAnomalyResultException(
-                                            executionStartTime,
-                                            executionEndTime,
-                                            Throwables.getStackTraceAsString(exception),
-                                            null,
-                                            detector
-                                        );
-                                }
-                            )
-                    );
-            } catch (Exception ex) {
-                listener.onFailure(ex);
-                return;
-            }
-            listener.onResponse(r);
-
-        }, listener::onFailure);
-        if (!anomalyDetectionIndices.doesJobIndexExist()) {
-            anomalyDetectionIndices.initJobIndex(ActionListener.wrap(response -> {
-                if (response.isAcknowledged()) {
-                    logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX);
-                    createJob(detector, startListener);
-                } else {
-                    logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX);
-                    startListener
-                        .onFailure(
-                            new OpenSearchStatusException(
-                                "Created " + CommonName.CONFIG_INDEX + " with mappings call not acknowledged.",
-                                RestStatus.INTERNAL_SERVER_ERROR
-                            )
-                        );
-                }
-            }, exception -> startListener.onFailure(exception)));
-        } else {
-            createJob(detector, startListener);
-        }
-    }
-
-    private void createJob(AnomalyDetector detector, ActionListener<AnomalyDetectorJobResponse> listener) {
-        try {
-            IntervalTimeConfiguration interval = (IntervalTimeConfiguration) detector.getInterval();
-            Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit());
-            Duration duration = Duration.of(interval.getInterval(), interval.getUnit());
-
-            AnomalyDetectorJob job = new AnomalyDetectorJob(
-                detector.getId(),
-                schedule,
-                detector.getWindowDelay(),
-                true,
-                Instant.now(),
-                null,
-                Instant.now(),
-                duration.getSeconds(),
-                detector.getUser(),
-                detector.getCustomResultIndex()
-            );
-
-            getAnomalyDetectorJobForWrite(detector, job, listener);
-        } catch (Exception e) {
-            String message = "Failed to parse anomaly detector job " + detectorId;
-            logger.error(message, e);
-            listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-        }
-    }
-
-    private void getAnomalyDetectorJobForWrite(
-        AnomalyDetector detector,
-        AnomalyDetectorJob job,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId);
-
-        client
-            .get(
-                getRequest,
-                ActionListener
-                    .wrap(
-                        response -> onGetAnomalyDetectorJobForWrite(response, detector, job, listener),
-                        exception -> listener.onFailure(exception)
-                    )
-            );
-    }
-
-    private void onGetAnomalyDetectorJobForWrite(
-        GetResponse response,
-        AnomalyDetector detector,
-        AnomalyDetectorJob job,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) throws IOException {
-        if (response.isExists()) {
-            try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
-                ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                AnomalyDetectorJob currentAdJob = AnomalyDetectorJob.parse(parser);
-                if (currentAdJob.isEnabled()) {
-                    listener
-                        .onFailure(new OpenSearchStatusException("Anomaly detector job is already running: " + detectorId, RestStatus.OK));
-                    return;
-                } else {
-                    AnomalyDetectorJob newJob = new AnomalyDetectorJob(
-                        job.getName(),
-                        job.getSchedule(),
-                        job.getWindowDelay(),
-                        job.isEnabled(),
-                        Instant.now(),
-                        currentAdJob.getDisabledTime(),
-                        Instant.now(),
-                        job.getLockDurationSeconds(),
-                        job.getUser(),
-                        job.getCustomResultIndex()
-                    );
-                    // Get latest realtime task and check its state before index job. Will reset running realtime task
-                    // as STOPPED first if job disabled, then start new job and create new realtime task.
-                    adTaskManager
-                        .startDetector(
-                            detector,
-                            null,
-                            job.getUser(),
-                            transportService,
-                            ActionListener
-                                .wrap(
-                                    r -> { indexAnomalyDetectorJob(newJob, null, listener); },
-                                    e -> {
-                                        // Have logged error message in ADTaskManager#startDetector
-                                        listener.onFailure(e);
-                                    }
-                                )
-                        );
-                }
-            } catch (IOException e) {
-                String message = "Failed to parse anomaly detector job " + job.getName();
-                logger.error(message, e);
-                listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-            }
-        } else {
-            adTaskManager
-                .startDetector(
-                    detector,
-                    null,
-                    job.getUser(),
-                    transportService,
-                    ActionListener.wrap(r -> { indexAnomalyDetectorJob(job, null, listener); }, e -> listener.onFailure(e))
-                );
-        }
-    }
-
-    private void indexAnomalyDetectorJob(
-        AnomalyDetectorJob job,
-        ExecutorFunction function,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) throws IOException {
-        IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX)
-            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
-            .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE))
-            .setIfSeqNo(seqNo)
-            .setIfPrimaryTerm(primaryTerm)
-            .timeout(requestTimeout)
-            .id(detectorId);
-        client
-            .index(
-                indexRequest,
-                ActionListener
-                    .wrap(
-                        response -> onIndexAnomalyDetectorJobResponse(response, function, listener),
-                        exception -> listener.onFailure(exception)
-                    )
-            );
-    }
-
-    private void onIndexAnomalyDetectorJobResponse(
-        IndexResponse response,
-        ExecutorFunction function,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) {
-            String errorMsg = getShardsFailure(response);
-            listener.onFailure(new OpenSearchStatusException(errorMsg, response.status()));
-            return;
-        }
-        if (function != null) {
-            function.execute();
-        } else {
-            AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse(
-                response.getId(),
-                response.getVersion(),
-                response.getSeqNo(),
-                response.getPrimaryTerm(),
-                RestStatus.OK
-            );
-            listener.onResponse(anomalyDetectorJobResponse);
-        }
-    }
-
-    /**
-     * Stop anomaly detector job.
-     * 1.If job not exists, return error message
-     * 2.If job exists: a).if job state is disabled, return error message; b).if job state is enabled, disable job.
-     *
-     * @param detectorId detector identifier
-     * @param listener Listener to send responses
-     */
-    public void stopAnomalyDetectorJob(String detectorId, ActionListener<AnomalyDetectorJobResponse> listener) {
-        GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId);
-
-        client.get(getRequest, ActionListener.wrap(response -> {
-            if (response.isExists()) {
-                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
-                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                    AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
-                    if (!job.isEnabled()) {
-                        adTaskManager.stopLatestRealtimeTask(detectorId, ADTaskState.STOPPED, null, transportService, listener);
-                    } else {
-                        AnomalyDetectorJob newJob = new AnomalyDetectorJob(
-                            job.getName(),
-                            job.getSchedule(),
-                            job.getWindowDelay(),
-                            false,
-                            job.getEnabledTime(),
-                            Instant.now(),
-                            Instant.now(),
-                            job.getLockDurationSeconds(),
-                            job.getUser(),
-                            job.getCustomResultIndex()
-                        );
-                        indexAnomalyDetectorJob(
-                            newJob,
-                            () -> client
-                                .execute(
-                                    StopDetectorAction.INSTANCE,
-                                    new StopDetectorRequest(detectorId),
-                                    stopAdDetectorListener(detectorId, listener)
-                                ),
-                            listener
-                        );
-                    }
-                } catch (IOException e) {
-                    String message = "Failed to parse anomaly detector job " + detectorId;
-                    logger.error(message, e);
-                    listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-                }
-            } else {
-                listener.onFailure(new OpenSearchStatusException("Anomaly detector job not exist: " + detectorId, RestStatus.BAD_REQUEST));
-            }
-        }, exception -> listener.onFailure(exception)));
-    }
-
-    private ActionListener<StopDetectorResponse> stopAdDetectorListener(
-        String detectorId,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        return new ActionListener<StopDetectorResponse>() {
-            @Override
-            public void onResponse(StopDetectorResponse stopDetectorResponse) {
-                if (stopDetectorResponse.success()) {
-                    logger.info("AD model deleted successfully for detector {}", detectorId);
-                    // StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache.
-                    // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache.
-                    adTaskManager.stopLatestRealtimeTask(detectorId, ADTaskState.STOPPED, null, null, listener);
-                } else {
-                    logger.error("Failed to delete AD model for detector {}", detectorId);
-                    // If failed to clear all realtime cache, will try to re-clear coordinating node cache.
-                    adTaskManager
-                        .stopLatestRealtimeTask(
-                            detectorId,
-                            ADTaskState.FAILED,
-                            new OpenSearchStatusException("Failed to delete AD model", RestStatus.INTERNAL_SERVER_ERROR),
-                            transportService,
-                            listener
-                        );
-                }
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                logger.error("Failed to delete AD model for detector " + detectorId, e);
-                // If failed to clear all realtime cache, will try to re-clear coordinating node cache.
-                adTaskManager
-                    .stopLatestRealtimeTask(
-                        detectorId,
-                        ADTaskState.FAILED,
-                        new OpenSearchStatusException("Failed to execute stop detector action", RestStatus.INTERNAL_SERVER_ERROR),
-                        transportService,
-                        listener
-                    );
-            }
-        };
-    }
-
-}
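The deleted job handler leaned on one idiom worth keeping in mind when reading its replacement: ActionListener.wrap builds a listener that performs a side effect (here, firing an AnomalyResultRequest so the first result is indexed) before delegating to the caller's listener. A self-contained sketch of that chaining, with a hand-rolled listener type so it runs without OpenSearch on the classpath:

    import java.util.function.Consumer;

    // Minimal stand-in for org.opensearch.action.ActionListener.
    interface Listener<T> {
        void onResponse(T response);

        void onFailure(Exception e);
    }

    public class ListenerChainDemo {
        static <T> Listener<T> wrap(Consumer<T> onResp, Consumer<Exception> onFail) {
            return new Listener<T>() {
                public void onResponse(T r) { onResp.accept(r); }

                public void onFailure(Exception e) { onFail.accept(e); }
            };
        }

        public static void main(String[] args) {
            Listener<String> caller = wrap(
                r -> System.out.println("job started: " + r),
                e -> System.err.println("failed: " + e.getMessage())
            );
            // startAnomalyDetectorJob built exactly this shape: run extra work
            // first (index the first anomaly result), then hand the original
            // response through to the caller.
            Listener<String> startListener = wrap(r -> {
                System.out.println("index first anomaly result"); // side effect
                caller.onResponse(r);                             // then delegate
            }, caller::onFailure);
            startListener.onResponse("job-1");
        }
    }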
diff --git a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java
index ada684808..1ffb271ff 100644
--- a/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java
+++ b/src/main/java/org/opensearch/ad/rest/handler/ModelValidationActionHandler.java
@@ -35,13 +35,9 @@
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.ad.constant.ADCommonMessages;
-import org.opensearch.ad.feature.SearchFeatureDao;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.MergeableList;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse;
-import org.opensearch.ad.util.MultiResponsesDelegateActionListener;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
@@ -68,15 +64,21 @@
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.search.sort.FieldSortBuilder;
 import org.opensearch.search.sort.SortOrder;
+import org.opensearch.timeseries.AnalysisType;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.ValidationException;
 import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.feature.SearchFeatureDao;
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.MergeableList;
 import org.opensearch.timeseries.model.TimeConfiguration;
 import org.opensearch.timeseries.model.ValidationAspect;
 import org.opensearch.timeseries.model.ValidationIssueType;
+import org.opensearch.timeseries.rest.handler.ConfigUpdateConfirmer;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
 import org.opensearch.timeseries.util.ParseUtils;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 
 /**
  * This class executes all validation checks that are not blocking on the 'model' level.
@@ -94,7 +96,7 @@ public class ModelValidationActionHandler {
     protected final ClusterService clusterService;
     protected final Logger logger = LogManager.getLogger(AbstractAnomalyDetectorActionHandler.class);
     protected final TimeValue requestTimeout;
-    protected final AnomalyDetectorActionHandler handler = new AnomalyDetectorActionHandler();
+    protected final ConfigUpdateConfirmer handler = new ConfigUpdateConfirmer();
     protected final Client client;
     protected final SecurityClientUtil clientUtil;
     protected final NamedXContentRegistry xContentRegistry;
@@ -104,6 +106,7 @@ public class ModelValidationActionHandler {
     protected final String validationType;
     protected final Settings settings;
     protected final User user;
+    protected final AnalysisType context;
 
     /**
      * Constructor function.
@@ -147,6 +150,7 @@ public ModelValidationActionHandler(
         this.clock = clock;
         this.settings = settings;
         this.user = user;
+        this.context = AnalysisType.AD;
     }
 
     // Need to first check if multi entity detector or not before doing any sort of validation.
@@ -253,6 +257,7 @@ private void getTopEntity(ActionListener<Map<String, Object>> topEntityListener)
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
@@ -344,6 +349,7 @@ private void getBucketAggregates(
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
@@ -461,6 +467,7 @@ public void onResponse(SearchResponse response) {
                 client::search,
                 user,
                 client,
+                context,
                 this
             );
             // In this case decreasingInterval has to be true already, so we will stop
@@ -495,6 +502,7 @@ private void searchWithDifferentInterval(long newIntervalMinuteValue) {
                 client::search,
                 user,
                 client,
+                context,
                 this
             );
     }
@@ -571,6 +579,7 @@ private void checkRawDataSparsity(long latestTime) {
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
@@ -631,6 +640,7 @@ private void checkDataFilterSparsity(long latestTime) {
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
@@ -693,6 +703,7 @@ private void checkCategoryFieldSparsity(Map<String, Object> topEntity, long latestTime) {
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
@@ -783,6 +794,7 @@ private void checkFeatureQueryDelegate(long latestTime) throws IOException {
                 client::search,
                 user,
                 client,
+                context,
                 searchResponseListener
             );
     }
 }
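Every hunk in this file makes the same mechanical change: an AnalysisType field is assigned once in the constructor and threaded into each security-injected search call, so shared helpers can distinguish AD from forecasting traffic without per-analysis overloads. A runnable sketch of the pattern; only the enum name AnalysisType and the AD constant come from the diff, the helper is hypothetical:

    enum AnalysisType { AD, FORECAST }

    public class ContextThreadingDemo {
        private final AnalysisType context = AnalysisType.AD; // set once, as in the handler

        void search(String request, AnalysisType context) {
            // A real helper could pick thread pools, settings, or stats buckets here.
            System.out.println("running " + context + " search: " + request);
        }

        void checkRawDataSparsity() {
            search("raw-data-sparsity", context); // the extra argument each hunk adds
        }

        public static void main(String[] args) {
            new ContextThreadingDemo().checkRawDataSparsity();
        }
    }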
diff --git a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java
index 163d1df63..56874615b 100644
--- a/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java
+++ b/src/main/java/org/opensearch/ad/rest/handler/ValidateAnomalyDetectorActionHandler.java
@@ -13,12 +13,9 @@
 
 import java.time.Clock;
 
-import org.opensearch.action.ActionListener;
-import org.opensearch.ad.feature.SearchFeatureDao;
 import org.opensearch.ad.indices.ADIndexManagement;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
@@ -26,6 +23,9 @@
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.rest.RestRequest;
+import org.opensearch.timeseries.feature.SearchFeatureDao;
+import org.opensearch.timeseries.model.Config;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 
 /**
  * Anomaly detector REST action handler to process POST request.
@@ -39,13 +39,13 @@ public class ValidateAnomalyDetectorActionHandler extends AbstractAnomalyDetecto
  * @param clusterService ClusterService
  * @param client ES node client that executes actions on the local node
  * @param clientUtil AD client utility
- * @param listener ES channel used to construct bytes / builder based outputs, and send responses
 * @param anomalyDetectionIndices anomaly detector index manager
 * @param anomalyDetector anomaly detector instance
 * @param requestTimeout request time out configuration
 * @param maxSingleEntityAnomalyDetectors max single-entity anomaly detectors allowed
 * @param maxMultiEntityAnomalyDetectors max multi-entity detectors allowed
 * @param maxAnomalyFeatures max features allowed per detector
+ * @param maxCategoricalFields max number of categorical fields
 * @param method Rest Method type
 * @param xContentRegistry Registry which is used for XContentParser
 * @param user User context
@@ -58,13 +58,13 @@ public ValidateAnomalyDetectorActionHandler(
         ClusterService clusterService,
         Client client,
         SecurityClientUtil clientUtil,
-        ActionListener<ValidateAnomalyDetectorResponse> listener,
         ADIndexManagement anomalyDetectionIndices,
         AnomalyDetector anomalyDetector,
         TimeValue requestTimeout,
         Integer maxSingleEntityAnomalyDetectors,
         Integer maxMultiEntityAnomalyDetectors,
         Integer maxAnomalyFeatures,
+        Integer maxCategoricalFields,
         RestRequest.Method method,
         NamedXContentRegistry xContentRegistry,
         User user,
@@ -78,9 +78,8 @@ public ValidateAnomalyDetectorActionHandler(
             client,
             clientUtil,
             null,
-            listener,
             anomalyDetectionIndices,
-            AnomalyDetector.NO_ID,
+            Config.NO_ID,
             null,
             null,
             null,
@@ -89,6 +88,7 @@ public ValidateAnomalyDetectorActionHandler(
             maxSingleEntityAnomalyDetectors,
             maxMultiEntityAnomalyDetectors,
             maxAnomalyFeatures,
+            maxCategoricalFields,
             method,
             xContentRegistry,
             user,
@@ -100,16 +100,4 @@ public ValidateAnomalyDetectorActionHandler(
             settings
         );
     }
-
-    // If validation type is detector then all validation in AbstractAnomalyDetectorActionHandler that is called
-    // by super.start() involves validation checks against the detector configurations,
-    // any issues raised here would block user from creating the anomaly detector.
-    // If validation Aspect is of type model then further non-blocker validation will be executed
-    // after the blocker validation is executed. Any issues that are raised for model validation
-    // are simply warnings for the user in terms of how configuration could be changed to lead to
-    // a higher likelihood of model training completing successfully
-    @Override
-    public void start() {
-        super.start();
-    }
 }
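The validate handler passes Config.NO_ID and null seqNo/primaryTerm to the base constructor because a validation request is a pure dry run: the config is never indexed, so the identifiers that guard optimistic-concurrency writes are never read. A small runnable sketch of that contract; class and field names here are illustrative, not from the plugin:

    public class DryRunDemo {
        static final String NO_ID = "";

        static void indexConfig(String id, Long seqNo, Long primaryTerm, boolean dryRun) {
            if (dryRun) {
                System.out.println("validation only; skip indexing");
                return; // seqNo/primaryTerm are never dereferenced on this path
            }
            System.out.printf("index %s at seqNo=%d, term=%d%n", id, seqNo, primaryTerm);
        }

        public static void main(String[] args) {
            indexConfig(NO_ID, null, null, true);     // validate path
            indexConfig("detector-1", 3L, 1L, false); // create/update path
        }
    }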
*/ package org.opensearch.ad.settings; diff --git a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java index e064867a0..869cdf412 100644 --- a/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java +++ b/src/main/java/org/opensearch/ad/settings/ADNumericSetting.java @@ -1,12 +1,6 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ package org.opensearch.ad.settings; diff --git a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java index 22e72eba0..e36e41e19 100644 --- a/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java +++ b/src/main/java/org/opensearch/ad/settings/AnomalyDetectorSettings.java @@ -11,8 +11,6 @@ package org.opensearch.ad.settings; -import java.time.Duration; - import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.TimeValue; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -25,7 +23,7 @@ public final class AnomalyDetectorSettings { private AnomalyDetectorSettings() {} public static final int MAX_DETECTOR_UPPER_LIMIT = 10000; - public static final Setting MAX_SINGLE_ENTITY_ANOMALY_DETECTORS = Setting + public static final Setting AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS = Setting .intSetting( "plugins.anomaly_detection.max_anomaly_detectors", LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, @@ -35,7 +33,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting MAX_MULTI_ENTITY_ANOMALY_DETECTORS = Setting + public static final Setting AD_MAX_HC_ANOMALY_DETECTORS = Setting .intSetting( "plugins.anomaly_detection.max_multi_entity_anomaly_detectors", LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, @@ -55,7 +53,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting REQUEST_TIMEOUT = Setting + public static final Setting AD_REQUEST_TIMEOUT = Setting .positiveTimeSetting( "plugins.anomaly_detection.request_timeout", LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT, @@ -114,7 +112,13 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting + /** + * @deprecated This setting is deprecated because we need to manage fault tolerance for + * multiple analysis such as AD and forecasting. + * Use {@link #TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE} instead. + */ + @Deprecated + public static final Setting AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE = Setting .intSetting( "plugins.anomaly_detection.max_retry_for_unresponsive_node", LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, @@ -123,7 +127,13 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting COOLDOWN_MINUTES = Setting + /** + * @deprecated This setting is deprecated because we need to manage fault tolerance for + * multiple analysis such as AD and forecasting. + * Use {@link #TimeSeriesSettings.COOLDOWN_MINUTES} instead. 
+ */ + @Deprecated + public static final Setting AD_COOLDOWN_MINUTES = Setting .positiveTimeSetting( "plugins.anomaly_detection.cooldown_minutes", LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES, @@ -131,7 +141,13 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting BACKOFF_MINUTES = Setting + /** + * @deprecated This setting is deprecated because we need to manage fault tolerance for + * multiple analysis such as AD and forecasting. + * Use {@link #TimeSeriesSettings.BACKOFF_MINUTES} instead. + */ + @Deprecated + public static final Setting AD_BACKOFF_MINUTES = Setting .positiveTimeSetting( "plugins.anomaly_detection.backoff_minutes", LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, @@ -156,7 +172,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting + public static final Setting AD_MAX_RETRY_FOR_END_RUN_EXCEPTION = Setting .intSetting( "plugins.anomaly_detection.max_retry_for_end_run_exception", LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION, @@ -165,7 +181,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting FILTER_BY_BACKEND_ROLES = Setting + public static final Setting AD_FILTER_BY_BACKEND_ROLES = Setting .boolSetting( "plugins.anomaly_detection.filter_by_backend_roles", LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, @@ -177,14 +193,12 @@ private AnomalyDetectorSettings() {} public static final String ANOMALY_DETECTION_STATE_INDEX_MAPPING_FILE = "mappings/anomaly-detection-state.json"; public static final String CHECKPOINT_INDEX_MAPPING_FILE = "mappings/anomaly-checkpoint.json"; - public static final Duration HOURLY_MAINTENANCE = Duration.ofHours(1); - // saving checkpoint every 12 hours. // To support 1 million entities in 36 data nodes, each node has roughly 28K models. // In each hour, we roughly need to save 2400 models. Since each model saving can - // take about 1 seconds (default value of AnomalyDetectorSettings.EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_SECS) + // take about 1 seconds (default value of AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS) // we can use up to 2400 seconds to finish saving checkpoints. - public static final Setting CHECKPOINT_SAVING_FREQ = Setting + public static final Setting AD_CHECKPOINT_SAVING_FREQ = Setting .positiveTimeSetting( "plugins.anomaly_detection.checkpoint_saving_freq", TimeValue.timeValueHours(12), @@ -192,7 +206,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting CHECKPOINT_TTL = Setting + public static final Setting AD_CHECKPOINT_TTL = Setting .positiveTimeSetting( "plugins.anomaly_detection.checkpoint_ttl", TimeValue.timeValueDays(7), @@ -203,52 +217,16 @@ private AnomalyDetectorSettings() {} // ====================================== // ML parameters // ====================================== - // RCF - public static final int NUM_SAMPLES_PER_TREE = 256; - - public static final int NUM_TREES = 30; - - public static final int TRAINING_SAMPLE_INTERVAL = 64; - - public static final double TIME_DECAY = 0.0001; - - // If we have 32 + shingleSize (hopefully recent) values, RCF can get up and running. It will be noisy — - // there is a reason that default size is 256 (+ shingle size), but it may be more useful for people to - /// start seeing some results. 
- public static final int NUM_MIN_SAMPLES = 32; - - // The threshold for splitting RCF models in single-stream detectors. - // The smallest machine in the Amazon managed service has 1GB heap. - // With the setting, the desired model size there is of 2 MB. - // By default, we can have at most 5 features. Since the default shingle size - // is 8, we have at most 40 dimensions in RCF. In our current RCF setting, - // 30 trees, and bounding box cache ratio 0, 40 dimensions use 449KB. - // Users can increase the number of features to 10 and shingle size to 60, - // 30 trees, bounding box cache ratio 0, 600 dimensions use 1.8 MB. - // Since these sizes are smaller than the threshold 2 MB, we won't split models - // even in the smallest machine. - public static final double DESIRED_MODEL_SIZE_PERCENTAGE = 0.002; - - public static final Setting MODEL_MAX_SIZE_PERCENTAGE = Setting + public static final Setting AD_MODEL_MAX_SIZE_PERCENTAGE = Setting .doubleSetting( "plugins.anomaly_detection.model_max_size_percent", LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, 0, - 0.7, + 0.9, Setting.Property.NodeScope, Setting.Property.Dynamic ); - // for a batch operation, we want all of the bounding box in-place for speed - public static final double BATCH_BOUNDING_BOX_CACHE_RATIO = 1; - - // Thresholding - public static final double THRESHOLD_MIN_PVALUE = 0.995; - - public static final double THRESHOLD_MAX_RANK_ERROR = 0.0001; - - public static final double THRESHOLD_MAX_SCORE = 8; - public static final int THRESHOLD_NUM_LOGNORMAL_QUANTILES = 400; public static final int THRESHOLD_DOWNSAMPLES = 5_000; @@ -269,9 +247,6 @@ private AnomalyDetectorSettings() {} // shingling public static final double MAX_SHINGLE_PROPORTION_MISSING = 0.25; - // AD JOB - public static final long DEFAULT_AD_JOB_LOC_DURATION_SECONDS = 60; - // Thread pool public static final int AD_THEAD_POOL_QUEUE_SIZE = 1000; @@ -293,7 +268,7 @@ private AnomalyDetectorSettings() {} * Other detectors cannot use space reserved by a detector's dedicated cache. * DEDICATED_CACHE_SIZE is a setting to make dedicated cache's size flexible. * When that setting is changed, if the size decreases, we will release memory - * if required (e.g., when a user also decreased AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, + * if required (e.g., when a user also decreased AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE, * the max memory percentage that AD can use); * if the size increases, we may reject the setting change if we cannot fulfill * that request (e.g., when it will uses more memory than allowed for AD). @@ -305,25 +280,13 @@ private AnomalyDetectorSettings() {} * where 3.2 GB is from 10% memory limit of AD plugin. * That's why I am using 60_000 as the max limit. */ - public static final Setting DEDICATED_CACHE_SIZE = Setting + public static final Setting AD_DEDICATED_CACHE_SIZE = Setting .intSetting("plugins.anomaly_detection.dedicated_cache_size", 10, 0, 60_000, Setting.Property.NodeScope, Setting.Property.Dynamic); // We only keep priority (4 bytes float) in inactive cache. 1 million priorities // take up 4 MB. 
public static final int MAX_INACTIVE_ENTITIES = 1_000_000; - // Increase the value will adding pressure to indexing anomaly results and our feature query - // OpenSearch-only setting as previous the legacy default is too low (1000) - public static final Setting<Integer> MAX_ENTITIES_PER_QUERY = Setting - .intSetting( - "plugins.anomaly_detection.max_entities_per_query", - 1_000_000, - 0, - 2_000_000, - Setting.Property.NodeScope, - Setting.Property.Dynamic - ); - // save partial zero-anomaly grade results after indexing pressure reaching the limit // Opendistro version has similar setting. I lowered the value to make room // for INDEX_PRESSURE_HARD_LIMIT. I don't find a floatSetting that has both default @@ -361,12 +324,6 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - // max entity value's length - public static int MAX_ENTITY_LENGTH = 256; - - // number of bulk checkpoints per second - public static double CHECKPOINT_BULK_PER_SECOND = 0.02; - // ====================================== // Historical analysis // ====================================== @@ -382,6 +339,8 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); + // Use TimeSeriesSettings.MAX_CACHED_DELETED_TASKS for both AD and forecasting + @Deprecated // Maximum number of deleted tasks that can be kept in cache. public static final Setting<Integer> MAX_CACHED_DELETED_TASKS = Setting .intSetting( @@ -455,7 +414,7 @@ private AnomalyDetectorSettings() {} // ====================================== // the percentage of heap usage allowed for queues holding small requests // set it to 0 to disable the queue - public static final Setting<Float> COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting( "plugins.anomaly_detection.cold_entity_queue_max_heap_percent", 0.001f, @@ -464,7 +423,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting<Float> CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting( "plugins.anomaly_detection.checkpoint_read_queue_max_heap_percent", 0.001f, @@ -473,7 +432,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting<Float> ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting( "plugins.anomaly_detection.entity_cold_start_queue_max_heap_percent", 0.001f, @@ -484,7 +443,7 @@ private AnomalyDetectorSettings() {} // the percentage of heap usage allowed for queues holding large requests // set it to 0 to disable the queue - public static final Setting<Float> CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting( "plugins.anomaly_detection.checkpoint_write_queue_max_heap_percent", 0.01f, @@ -493,7 +452,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting<Float> RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting( "plugins.anomaly_detection.result_write_queue_max_heap_percent", 0.01f, @@ -502,7 +461,7 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Setting<Float> CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT = Setting + public static final Setting<Float> AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT = Setting .floatSetting(
"plugins.anomaly_detection.checkpoint_maintain_queue_max_heap_percent", 0.001f, @@ -514,7 +473,7 @@ private AnomalyDetectorSettings() {} // expected execution time per cold entity request. This setting controls // the speed of cold entity requests execution. The larger, the faster, and // the more performance impact to customers' workload. - public static final Setting EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS = Setting + public static final Setting AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS = Setting .intSetting( "plugins.anomaly_detection.expected_cold_entity_execution_time_in_millisecs", 3000, @@ -537,73 +496,10 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - /** - * EntityRequest has entityName (# category fields * 256, the recommended limit - * of a keyword field length), model Id (roughly 256 bytes), and QueuedRequest - * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, - * 8 bytes), and priority (12 bytes). - * Plus Java object size (12 bytes), we have roughly 928 bytes per request - * assuming we have 2 categorical fields (plan to support 2 categorical fields now). - * We don't want the total size exceeds 0.1% of the heap. - * We can have at most 0.1% heap / 928 = heap / 928,000. - * For t3.small, 0.1% heap is of 1MB. The queue's size is up to - * 10^ 6 / 928 = 1078 - */ - public static int ENTITY_REQUEST_SIZE_IN_BYTES = 928; - - /** - * EntityFeatureRequest consists of EntityRequest (928 bytes, read comments - * of ENTITY_COLD_START_QUEUE_SIZE_CONSTANT), pointer to current feature - * (8 bytes), and dataStartTimeMillis (8 bytes). We have roughly - * 928 + 16 = 944 bytes per request. - * - * We don't want the total size exceeds 0.1% of the heap. - * We should have at most 0.1% heap / 944 = heap / 944,000 - * For t3.small, 0.1% heap is of 1MB. The queue's size is up to - * 10^ 6 / 944 = 1059 - */ - public static int ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES = 944; - - /** - * ResultWriteRequest consists of index request (roughly 1KB), and QueuedRequest - * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). - * Plus Java object size (12 bytes), we have roughly 1160 bytes per request - * - * We don't want the total size exceeds 1% of the heap. - * We should have at most 1% heap / 1148 = heap / 116,000 - * For t3.small, 1% heap is of 10MB. The queue's size is up to - * 10^ 7 / 1160 = 8621 - */ - public static int RESULT_WRITE_QUEUE_SIZE_IN_BYTES = 1160; - - /** - * CheckpointWriteRequest consists of IndexRequest (200 KB), and QueuedRequest - * fields (148 bytes, read comments of ENTITY_REQUEST_SIZE_CONSTANT). - * The total is roughly 200 KB per request. - * - * We don't want the total size exceeds 1% of the heap. - * We should have at most 1% heap / 200KB = heap / 20,000,000 - * For t3.small, 1% heap is of 10MB. The queue's size is up to - * 10^ 7 / 2.0 * 10^5 = 50 - */ - public static int CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES = 200_000; - - /** - * CheckpointMaintainRequest has model Id (roughly 256 bytes), and QueuedRequest - * fields including detector Id(roughly 128 bytes), expirationEpochMs (long, - * 8 bytes), and priority (12 bytes). - * Plus Java object size (12 bytes), we have roughly 416 bytes per request. - * We don't want the total size exceeds 0.1% of the heap. - * We can have at most 0.1% heap / 416 = heap / 416,000. - * For t3.small, 0.1% heap is of 1MB. 
The queue's size is up to - * 10^ 6 / 416 = 2403 - */ - public static int CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES = 416; - /** * Max concurrent entity cold starts per node */ - public static final Setting ENTITY_COLD_START_QUEUE_CONCURRENCY = Setting + public static final Setting AD_ENTITY_COLD_START_QUEUE_CONCURRENCY = Setting .intSetting( "plugins.anomaly_detection.entity_cold_start_queue_concurrency", 1, @@ -696,48 +592,24 @@ private AnomalyDetectorSettings() {} Setting.Property.Dynamic ); - public static final Duration QUEUE_MAINTENANCE = Duration.ofMinutes(10); - - public static final float MAX_QUEUED_TASKS_RATIO = 0.5f; - - public static final float MEDIUM_SEGMENT_PRUNE_RATIO = 0.1f; - - public static final float LOW_SEGMENT_PRUNE_RATIO = 0.3f; - - // expensive maintenance (e.g., queue maintenance) with 1/10000 probability - public static final int MAINTENANCE_FREQ_CONSTANT = 10000; - - // ====================================== - // Checkpoint setting - // ====================================== - // we won't accept a checkpoint larger than 30MB. Or we risk OOM. - // For reference, in RCF 1.0, the checkpoint of a RCF with 50 trees, 10 dimensions, - // 256 samples is of 3.2MB. - // In compact rcf, the same RCF is of 163KB. - // Since we allow at most 5 features, and the default shingle size is 8 and default - // tree number size is 100, we can have at most 25.6 MB in RCF 1.0. - // It is possible that cx increases the max features or shingle size, but we don't want - // to risk OOM for the flexibility. - public static final int MAX_CHECKPOINT_BYTES = 30_000_000; - - // Sets the cap on the number of buffer that can be allocated by the rcf deserialization - // buffer pool. Each buffer is of 512 bytes. Memory occupied by 20 buffers is 10.24 KB. - public static final int MAX_TOTAL_RCF_SERIALIZATION_BUFFERS = 20; - - // the size of the buffer used for rcf deserialization - public static final int SERIALIZATION_BUFFER_BYTES = 512; - // ====================================== // pagination setting // ====================================== // pagination size - public static final Setting PAGE_SIZE = Setting + public static final Setting AD_PAGE_SIZE = Setting .intSetting("plugins.anomaly_detection.page_size", 1_000, 0, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. - public static final float INTERVAL_RATIO_FOR_REQUESTS = 0.9f; + // Increase the value will adding pressure to indexing anomaly results and our feature query + // OpenSearch-only setting as previous the legacy default is too low (1000) + public static final Setting AD_MAX_ENTITIES_PER_QUERY = Setting + .intSetting( + "plugins.anomaly_detection.max_entities_per_query", + 1_000_000, + 0, + 2_000_000, + Setting.Property.NodeScope, + Setting.Property.Dynamic + ); // ====================================== // preview setting @@ -788,7 +660,7 @@ private AnomalyDetectorSettings() {} // ====================================== // the max number of models to return per node. 
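The removed request-size javadocs above all apply the same formula, so it is worth seeing it once in runnable form. A compact sketch using the byte constants quoted in those comments (the class name is illustrative):

```java
// capacity = (heap * maxHeapPercent) / bytesPerRequest, matching the removed
// ENTITY_REQUEST_SIZE_IN_BYTES / RESULT_WRITE_QUEUE_SIZE_IN_BYTES comments.
public final class QueueCapacityMath {
    static long capacity(long heapBytes, double maxHeapPercent, long bytesPerRequest) {
        return (long) (heapBytes * maxHeapPercent) / bytesPerRequest;
    }

    public static void main(String[] args) {
        long t3SmallHeap = 1_000_000_000L;                         // ~1 GB heap
        System.out.println(capacity(t3SmallHeap, 0.001, 928));     // ~1077 queued entity requests
        System.out.println(capacity(t3SmallHeap, 0.01, 1_160));    // ~8620 queued result writes
        System.out.println(capacity(t3SmallHeap, 0.01, 200_000));  // ~50 queued checkpoint writes
    }
}
```

The checkpoint write queue is tiny because a single serialized model dominates the per-request cost, while the small-request queues can hold thousands of entries within the same heap percentage.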
// the max number of models to return per node. // the setting is used to limit resource usage due to showing models - public static final Setting<Integer> MAX_MODEL_SIZE_PER_NODE = Setting + public static final Setting<Integer> AD_MAX_MODEL_SIZE_PER_NODE = Setting .intSetting( "plugins.anomaly_detection.max_model_size_per_node", 100, @@ -807,11 +679,6 @@ private AnomalyDetectorSettings() {} // total entities up to 10,000. public static final int MAX_TOTAL_ENTITIES_TO_TRACK = 10_000; - // ====================================== - // Cold start setting - // ====================================== - public static int MAX_COLD_START_ROUNDS = 2; - // ====================================== // Validate Detector API setting // ====================================== diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java deleted file mode 100644 index 3f5421032..000000000 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplier.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.stats.suppliers; - -import static org.opensearch.ad.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.LAST_USED_TIME_KEY; -import static org.opensearch.ad.ml.ModelState.MODEL_TYPE_KEY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.Supplier; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.timeseries.constant.CommonName; - -/** - * ModelsOnNodeSupplier provides a List of ModelStates info for the models the nodes contains - */ -public class ModelsOnNodeSupplier implements Supplier<List<Map<String, Object>>> { - private ModelManager modelManager; - private CacheProvider cache; - // the max number of models to return per node. Defaults to 100. - private volatile int numModelsToReturn; - - /** - * Set that contains the model stats that should be exposed. - */ - public static Set<String> MODEL_STATE_STAT_KEYS = new HashSet<>( - Arrays .asList( - CommonName.MODEL_ID_FIELD, - ADCommonName.DETECTOR_ID_KEY, - MODEL_TYPE_KEY, - CommonName.ENTITY_KEY, - LAST_USED_TIME_KEY, - LAST_CHECKPOINT_TIME_KEY - ) - ); - - /** - * Constructor - * - * @param modelManager object that manages the model partitions hosted on the node - * @param cache object that manages multi-entity detectors' models - * @param settings node settings accessor - * @param clusterService Cluster service accessor - */ - public ModelsOnNodeSupplier(ModelManager modelManager, CacheProvider cache, Settings settings, ClusterService clusterService) { - this.modelManager = modelManager; - this.cache = cache; - this.numModelsToReturn = MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); - } - - @Override - public List<Map<String, Object>> get() { - List<Map<String, Object>> values = new ArrayList<>(); - Stream .concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()) - .limit(numModelsToReturn) - .forEach( - modelState -> values .add( - modelState .getModelStateAsMap() .entrySet() .stream() .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) - ) - ); - - return values; - } -} diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java index 00c574669..d4bf9b763 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskCache.java @@ -11,11 +11,6 @@ package org.opensearch.ad.task; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.TIME_DECAY; - import java.util.ArrayDeque; import java.util.Deque; import java.util.Map; @@ -26,8 +21,8 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.amazon.randomcutforest.config.Precision; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -55,13 +50,13 @@ public class ADBatchTaskCache { private Entity entity; protected ADBatchTaskCache(ADTask adTask) { - this.detectorId = adTask.getId(); + this.detectorId = adTask.getConfigId(); this.taskId = adTask.getTaskId(); this.detectorTaskId = adTask.getDetectorLevelTaskId(); this.entity = adTask.getEntity(); AnomalyDetector detector = adTask.getDetector(); - int numberOfTrees = NUM_TREES; + int numberOfTrees = TimeSeriesSettings.NUM_TREES; int shingleSize = detector.getShingleSize(); this.shingle = new ArrayDeque<>(shingleSize); int dimensions = detector.getShingleSize() * detector.getEnabledFeatureIds().size(); @@ -70,16 +65,16 @@ protected ADBatchTaskCache(ADTask adTask) { .builder() .dimensions(dimensions) .numberOfTrees(numberOfTrees) - .timeDecay(TIME_DECAY) - .sampleSize(NUM_SAMPLES_PER_TREE) - .outputAfter(NUM_MIN_SAMPLES) - .initialAcceptFraction(NUM_MIN_SAMPLES * 1.0d / NUM_SAMPLES_PER_TREE) + .timeDecay(TimeSeriesSettings.TIME_DECAY) +
.sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE) + .outputAfter(TimeSeriesSettings.NUM_MIN_SAMPLES) + .initialAcceptFraction(TimeSeriesSettings.NUM_MIN_SAMPLES * 1.0d / TimeSeriesSettings.NUM_SAMPLES_PER_TREE) .parallelExecutionEnabled(false) .compact(true) .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .boundingBoxCacheFraction(TimeSeriesSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) .shingleSize(shingleSize) - .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE) + .anomalyRate(1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE) .build(); this.thresholdModelTrained = false; diff --git a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java index 2140ecf10..3dfe0bb71 100644 --- a/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java +++ b/src/main/java/org/opensearch/ad/task/ADBatchTaskRunner.java @@ -11,25 +11,17 @@ package org.opensearch.ad.task; -import static org.opensearch.ad.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR; -import static org.opensearch.ad.model.ADTask.CURRENT_PIECE_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; -import static org.opensearch.ad.model.ADTask.WORKER_NODE_FIELD; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_FOR_HISTORICAL_ANALYSIS; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_TOP_ENTITIES_LIMIT_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; -import static org.opensearch.ad.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; +import static org.opensearch.timeseries.stats.InternalStatNames.JVM_HEAP_USAGE; import static org.opensearch.timeseries.stats.StatNames.AD_EXECUTING_BATCH_TASK_COUNT; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; import java.time.Clock; import java.time.Instant; @@ -50,33 +42,22 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.PriorityTracker; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; -import org.opensearch.ad.feature.SinglePointFeatures; import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import 
org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStats; import org.opensearch.ad.transport.ADBatchAnomalyResultRequest; import org.opensearch.ad.transport.ADBatchAnomalyResultResponse; import org.opensearch.ad.transport.ADBatchTaskRemoteExecutionAction; import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesAction; import org.opensearch.ad.transport.ADStatsRequest; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; -import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -95,19 +76,33 @@ import org.opensearch.search.aggregations.metrics.InternalMin; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.PriorityTracker; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TaskCancelledException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportService; @@ -125,19 +120,19 @@ public class ADBatchTaskRunner { private final ThreadPool threadPool; private final Client client; private final SecurityClientUtil clientUtil; - private final ADStats adStats; + private final Stats adStats; private final ClusterService clusterService; private final FeatureManager featureManager; - private final ADCircuitBreakerService adCircuitBreakerService; + private final CircuitBreakerService adCircuitBreakerService; private final ADTaskManager adTaskManager; - private final AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler; + 
private final ResultBulkIndexingHandler anomalyResultBulkIndexHandler; private final ADIndexManagement anomalyDetectionIndices; private final SearchFeatureDao searchFeatureDao; private final ADTaskCacheManager adTaskCacheManager; private final TransportRequestOptions option; private final HashRing hashRing; - private final ModelManager modelManager; + private final ADModelManager modelManager; private volatile Integer maxAdBatchTaskPerNode; private volatile Integer pieceSize; @@ -147,6 +142,7 @@ public class ADBatchTaskRunner { private static final int MAX_TOP_ENTITY_SEARCH_BUCKETS = 1000; private static final int SLEEP_TIME_FOR_NEXT_ENTITY_TASK_IN_MILLIS = 2000; + private AnalysisType context; public ADBatchTaskRunner( Settings settings, @@ -154,16 +150,16 @@ public ADBatchTaskRunner( ClusterService clusterService, Client client, SecurityClientUtil clientUtil, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, FeatureManager featureManager, ADTaskManager adTaskManager, ADIndexManagement anomalyDetectionIndices, - ADStats adStats, - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler, + Stats adStats, + ResultBulkIndexingHandler anomalyResultBulkIndexHandler, ADTaskCacheManager adTaskCacheManager, SearchFeatureDao searchFeatureDao, HashRing hashRing, - ModelManager modelManager + ADModelManager modelManager ) { this.settings = settings; this.threadPool = threadPool; @@ -180,7 +176,7 @@ public ADBatchTaskRunner( this.option = TransportRequestOptions .builder() .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) .build(); this.adTaskCacheManager = adTaskCacheManager; @@ -219,7 +215,7 @@ public ADBatchTaskRunner( */ public void run(ADTask adTask, TransportService transportService, ActionListener listener) { boolean isHCDetector = adTask.getDetector().isHighCardinality(); - if (isHCDetector && !adTaskCacheManager.topEntityInited(adTask.getId())) { + if (isHCDetector && !adTaskCacheManager.topEntityInited(adTask.getConfigId())) { // Initialize top entities for HC detector threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { ActionListener hcDelegatedListener = getInternalHCDelegatedListener(adTask); @@ -261,12 +257,12 @@ private ActionListener getTopEntitiesListener( ActionListener listener ) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); ActionListener actionListener = ActionListener.wrap(response -> { adTaskCacheManager.setTopEntityInited(detectorId); int totalEntities = adTaskCacheManager.getPendingEntityCount(detectorId); logger.info("Total top entities: {} for detector {}, task {}", totalEntities, detectorId, taskId); - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + hashRing.getNodesWithSameLocalVersion(dataNodes -> { int numberOfEligibleDataNodes = dataNodes.length; // maxAdBatchTaskPerNode means how many tasks can run per data node, which is a hard limit per node.
// maxRunningEntitiesPerDetector means how many entities can run per detector on whole cluster, which is @@ -294,7 +290,7 @@ private ActionListener getTopEntitiesListener( }, listener); }, e -> { logger.debug("Failed to run task " + taskId, e); - if (adTask.getTaskType().equals(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { + if (adTask.getTaskType().equals(ADTaskType.AD_HISTORICAL_HC_DETECTOR.name())) { adTaskManager.entityTaskDone(adTask, e, transportService); } listener.onFailure(e); @@ -389,16 +385,16 @@ private void searchTopEntitiesForMultiCategoryHC( logger.debug("finish searching top entities at " + System.currentTimeMillis()); List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); if (topNEntities.size() == 0) { - logger.error("There is no entity found for detector " + adTask.getId()); - internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + logger.error("There is no entity found for detector " + adTask.getConfigId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "No entity found")); return; } - adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); - adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + adTaskCacheManager.addPendingEntities(adTask.getConfigId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getConfigId(), topNEntities.size()); internalHCListener.onResponse("Get top entities done"); } }, e -> { - logger.error("Failed to get top entities for detector " + adTask.getId(), e); + logger.error("Failed to get top entities for detector " + adTask.getConfigId(), e); internalHCListener.onFailure(e); }); int minimumDocCount = Math.max((int) (bucketInterval / adTask.getDetector().getIntervalInMilliseconds()) / 2, 1); @@ -466,16 +462,16 @@ private void searchTopEntitiesForSingleCategoryHC( logger.debug("finish searching top entities at " + System.currentTimeMillis()); List topNEntities = priorityTracker.getTopNEntities(maxTopEntitiesPerHcDetector); if (topNEntities.size() == 0) { - logger.error("There is no entity found for detector " + adTask.getId()); - internalHCListener.onFailure(new ResourceNotFoundException(adTask.getId(), "No entity found")); + logger.error("There is no entity found for detector " + adTask.getConfigId()); + internalHCListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "No entity found")); return; } - adTaskCacheManager.addPendingEntities(adTask.getId(), topNEntities); - adTaskCacheManager.setTopEntityCount(adTask.getId(), topNEntities.size()); + adTaskCacheManager.addPendingEntities(adTask.getConfigId(), topNEntities); + adTaskCacheManager.setTopEntityCount(adTask.getConfigId(), topNEntities.size()); internalHCListener.onResponse("Get top entities done"); } }, e -> { - logger.error("Failed to get top entities for detector " + adTask.getId(), e); + logger.error("Failed to get top entities for detector " + adTask.getConfigId(), e); internalHCListener.onFailure(e); }); // using the original context in listener as user roles have no permissions for internal operations like fetching a @@ -487,6 +483,7 @@ private void searchTopEntitiesForSingleCategoryHC( // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. 
adTask.getUser(), client, + context, searchResponseListener ); } @@ -510,7 +507,7 @@ public void forwardOrExecuteADTask( ) { try { checkIfADTaskCancelledAndCleanupCache(adTask); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); AnomalyDetector detector = adTask.getDetector(); boolean isHCDetector = detector.isHighCardinality(); if (isHCDetector) { @@ -527,7 +524,7 @@ public void forwardOrExecuteADTask( }); // This is to handle retry case. To retry entity, we need to get the old entity task created before. Entity entity = adTaskManager.parseEntityFromString(entityString, adTask); - String parentTaskId = adTask.getTaskType().equals(ADTaskType.HISTORICAL_HC_ENTITY.name()) + String parentTaskId = adTask.getTaskType().equals(ADTaskType.AD_HISTORICAL_HC_ENTITY.name()) ? adTask.getParentTaskId() // For HISTORICAL_HC_ENTITY task, return its parent task id : adTask.getTaskId(); // For HISTORICAL_HC_DETECTOR task, its task id is parent task id adTaskManager @@ -535,7 +532,7 @@ public void forwardOrExecuteADTask( detectorId, parentTaskId, entity, - ImmutableList.of(ADTaskType.HISTORICAL_HC_ENTITY), + ImmutableList.of(ADTaskType.AD_HISTORICAL_HC_ENTITY), existingEntityTask -> { if (existingEntityTask.isPresent()) { // retry failed entity caused by limit exceed exception // TODO: if task failed due to limit exceed exception in half way, resume from the break point or just clear @@ -559,14 +556,14 @@ public void forwardOrExecuteADTask( logger.info("Create entity task for entity:{}", entityString); Instant now = Instant.now(); ADTask adEntityTask = new ADTask.Builder() - .detectorId(adTask.getId()) + .detectorId(adTask.getConfigId()) .detector(detector) .isLatest(true) - .taskType(ADTaskType.HISTORICAL_HC_ENTITY.name()) + .taskType(ADTaskType.AD_HISTORICAL_HC_ENTITY.name()) .executionStartTime(now) .taskProgress(0.0f) .initProgress(0.0f) - .state(ADTaskState.INIT.name()) + .state(TaskState.INIT.name()) .initProgress(0.0f) .lastUpdateTime(now) .startedBy(adTask.getStartedBy()) @@ -576,7 +573,7 @@ public void forwardOrExecuteADTask( .entity(entity) .parentTaskId(parentTaskId) .build(); - adTaskManager.createADTaskDirectly(adEntityTask, r -> { + adTaskManager.createTaskDirectly(adEntityTask, r -> { adEntityTask.setTaskId(r.getId()); ActionListener workerNodeResponseListener = workerNodeResponseListener( adEntityTask, @@ -593,15 +590,15 @@ public void forwardOrExecuteADTask( ); } else { Map updatedFields = new HashMap<>(); - updatedFields.put(STATE_FIELD, ADTaskState.INIT.name()); - updatedFields.put(INIT_PROGRESS_FIELD, 0.0f); + updatedFields.put(TimeSeriesTask.STATE_FIELD, TaskState.INIT.name()); + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f); ActionListener workerNodeResponseListener = workerNodeResponseListener( adTask, transportService, listener ); adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), updatedFields, ActionListener @@ -637,7 +634,7 @@ private ActionListener workerNodeResponseListener( if (adTask.isEntityTask()) { // When reach this line, the entity task already been put into worker node's cache. // Then it's safe to move entity from temp entities queue to running entities queue. 
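The comment above describes the hand-off the next line performs. For readers following the HC flow, the cache calls in this file (addPendingEntities, moveToRunningEntity) implement a pending, temp, and running entity lifecycle; here is a simplified, hypothetical sketch of that bookkeeping, bearing in mind that the real ADTaskCacheManager tracks far more per-entity state:

```java
import java.util.ArrayDeque;
import java.util.HashSet;
import java.util.Queue;
import java.util.Set;

// Simplified sketch of the pending -> temp -> running entity lifecycle used by
// the HC historical flow. Names are illustrative, not the plugin's API.
final class EntityLaneSketch {
    private final Queue<String> pending = new ArrayDeque<>();
    private final Set<String> temp = new HashSet<>();    // dispatched, awaiting worker ack
    private final Set<String> running = new HashSet<>(); // confirmed cached on a worker node

    void addPendingEntities(Iterable<String> entities) {
        entities.forEach(pending::add);
    }

    // Take the next entity for dispatch; it is parked in "temp" until the
    // worker node confirms it cached the entity task.
    String pollEntity() {
        String entity = pending.poll();
        if (entity != null) {
            temp.add(entity);
        }
        return entity;
    }

    // Mirrors moveToRunningEntity below: only dispatched entities move on.
    void moveToRunningEntity(String entity) {
        if (temp.remove(entity)) {
            running.add(entity);
        }
    }
}
```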
- adTaskCacheManager.moveToRunningEntity(adTask.getId(), adTaskManager.convertEntityToString(adTask)); + adTaskCacheManager.moveToRunningEntity(adTask.getConfigId(), adTaskManager.convertEntityToString(adTask)); } startNewEntityTaskLane(adTask, transportService); }, e -> { @@ -648,7 +645,7 @@ private ActionListener workerNodeResponseListener( if (adTask.getDetector().isHighCardinality()) { // Entity task done on worker node. Send entity task done message to coordinating node to poll next entity. adTaskManager.entityTaskDone(adTask, e, transportService); - if (adTaskCacheManager.getAvailableNewEntityTaskLanes(adTask.getId()) > 0) { + if (adTaskCacheManager.getAvailableNewEntityTaskLanes(adTask.getConfigId()) > 0) { // When reach this line, it means entity task failed to start on worker node // Sleep some time before starting new task lane. threadPool @@ -697,14 +694,14 @@ private void forwardOrExecuteEntityTask( // start new entity task lane private synchronized void startNewEntityTaskLane(ADTask adTask, TransportService transportService) { - if (adTask.getDetector().isHighCardinality() && adTaskCacheManager.getAndDecreaseEntityTaskLanes(adTask.getId()) > 0) { - logger.debug("start new task lane for detector {}", adTask.getId()); + if (adTask.getDetector().isHighCardinality() && adTaskCacheManager.getAndDecreaseEntityTaskLanes(adTask.getConfigId()) > 0) { + logger.debug("start new task lane for detector {}", adTask.getConfigId()); forwardOrExecuteADTask(adTask, transportService, getInternalHCDelegatedListener(adTask)); } } private void dispatchTask(ADTask adTask, ActionListener listener) { - hashRing.getNodesWithSameLocalAdVersion(dataNodes -> { + hashRing.getNodesWithSameLocalVersion(dataNodes -> { ADStatsRequest adStatsRequest = new ADStatsRequest(dataNodes); adStatsRequest.addAll(ImmutableSet.of(AD_EXECUTING_BATCH_TASK_COUNT.getName(), JVM_HEAP_USAGE.getName())); @@ -720,10 +717,10 @@ private void dispatchTask(ADTask adTask, ActionListener listener) .append(DEFAULT_JVM_HEAP_USAGE_THRESHOLD) .append("%. 
") .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) - .append(adTask.getId()); + .append(adTask.getConfigId()); String errorMessage = errorMessageBuilder.toString(); logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); - listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } candidateNodeResponse = candidateNodeResponse @@ -733,10 +730,10 @@ private void dispatchTask(ADTask adTask, ActionListener listener) if (candidateNodeResponse.size() == 0) { StringBuilder errorMessageBuilder = new StringBuilder("All nodes' executing batch tasks exceeds limitation ") .append(NO_ELIGIBLE_NODE_TO_RUN_DETECTOR) - .append(adTask.getId()); + .append(adTask.getConfigId()); String errorMessage = errorMessageBuilder.toString(); logger.warn(errorMessage + ", task id " + adTask.getTaskId() + ", " + adTask.getTaskType()); - listener.onFailure(new LimitExceededException(adTask.getId(), errorMessage)); + listener.onFailure(new LimitExceededException(adTask.getConfigId(), errorMessage)); return; } Optional targetNode = candidateNodeResponse @@ -797,7 +794,7 @@ public void startADBatchTaskOnWorkerNode( private ActionListener internalBatchTaskListener(ADTask adTask, TransportService transportService) { String taskId = adTask.getTaskId(); String detectorTaskId = adTask.getDetectorLevelTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); ActionListener listener = ActionListener.wrap(response -> { // If batch task finished normally, remove task from cache and decrease executing task count by 1. adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); @@ -808,11 +805,11 @@ private ActionListener internalBatchTaskListener(ADTask adTask, Transpor .cleanDetectorCache( adTask, transportService, - () -> adTaskManager.updateADTask(taskId, ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())) + () -> adTaskManager.updateTask(taskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.FINISHED.name())) ); } else { // Set entity task as FINISHED here - adTaskManager.updateADTask(adTask.getTaskId(), ImmutableMap.of(STATE_FIELD, ADTaskState.FINISHED.name())); + adTaskManager.updateTask(adTask.getTaskId(), ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.FINISHED.name())); adTaskManager.entityTaskDone(adTask, null, transportService); } }, e -> { @@ -845,7 +842,7 @@ private void handleException(ADTask adTask, Exception e) { adStats.getStat(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName()).increment(); } // Handle AD task exception - adTaskManager.handleADTaskException(adTask, e); + adTaskManager.handleTaskException(adTask, e); } private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener internalListener) { @@ -864,7 +861,7 @@ private void executeADBatchTaskOnWorkerNode(ADTask adTask, ActionListener internalListener) { try { adTaskManager - .updateADTask( + .updateTask( adTask.getTaskId(), ImmutableMap .of( - STATE_FIELD, - ADTaskState.INIT.name(), - CURRENT_PIECE_FIELD, + TimeSeriesTask.STATE_FIELD, + TaskState.INIT.name(), + TimeSeriesTask.CURRENT_PIECE_FIELD, adTask.getDetectionDateRange().getStartTime().toEpochMilli(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 0.0f, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, 0.0f, - WORKER_NODE_FIELD, + TimeSeriesTask.WORKER_NODE_FIELD, clusterService.localNode().getId() ), ActionListener.wrap(r -> { @@ -919,7 +916,7 @@ private void 
runFirstPiece(ADTask adTask, Instant executeStartTime, ActionListen interval, dataStartTime, dataEndTime, - adTask.getId(), + adTask.getConfigId(), adTask.getTaskId() ); getFeatureData( @@ -969,7 +966,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons double maxValue = maxAgg.getValue(); // If time field not exist or there is no value, will return infinity value if (minValue == Double.POSITIVE_INFINITY) { - internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the time field")); + internalListener.onFailure(new ResourceNotFoundException(adTask.getConfigId(), "There is no data in the time field")); return; } long interval = ((IntervalTimeConfiguration) adTask.getDetector().getInterval()).toDuration().toMillis(); @@ -981,7 +978,8 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons long maxDate = (long) maxValue; if (minDate >= dataEndTime || maxDate <= dataStartTime) { - internalListener.onFailure(new ResourceNotFoundException(adTask.getId(), "There is no data in the detection date range")); + internalListener + .onFailure(new ResourceNotFoundException(adTask.getConfigId(), "There is no data in the detection date range")); return; } if (minDate > dataStartTime) { @@ -995,7 +993,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons dataStartTime = dataStartTime - dataStartTime % interval; dataEndTime = dataEndTime - dataEndTime % interval; logger.debug("adjusted date range: start: {}, end: {}, taskId: {}", dataStartTime, dataEndTime, taskId); - if ((dataEndTime - dataStartTime) < NUM_MIN_SAMPLES * interval) { + if ((dataEndTime - dataStartTime) < TimeSeriesSettings.NUM_MIN_SAMPLES * interval) { internalListener.onFailure(new TimeSeriesException("There is not enough data to train model").countedInStats(false)); return; } @@ -1010,6 +1008,7 @@ private void getDateRangeOfSourceData(ADTask adTask, BiConsumer cons // user is the one who started historical detector. Read AnomalyDetectorJobTransportAction.doExecute. adTask.getUser(), client, + context, searchResponseListener ); } @@ -1096,7 +1095,7 @@ private void detectAnomaly( ? "No full shingle in current detection window" : "No data in current detection window"; AnomalyResult anomalyResult = new AnomalyResult( - adTask.getId(), + adTask.getConfigId(), adTask.getDetectorLevelTaskId(), featureData, Instant.ofEpochMilli(intervalEndTime - interval), @@ -1122,7 +1121,7 @@ private void detectAnomaly( AnomalyResult anomalyResult = AnomalyResult .fromRawTRCFResult( - adTask.getId(), + adTask.getConfigId(), adTask.getDetector().getIntervalInMilliseconds(), adTask.getDetectorLevelTaskId(), score, @@ -1227,10 +1226,12 @@ private void storeAnomalyResultAndRunNextPiece( false ); + String detectorId = adTask.getConfigId(); anomalyResultBulkIndexHandler - .bulkIndexAnomalyResult( + .bulk( resultIndex, anomalyResults, + detectorId, runBefore == null ? actionListener : ActionListener.runBefore(actionListener, runBefore) ); } @@ -1244,14 +1245,14 @@ private void runNextPiece( ActionListener internalListener ) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String detectorTaskId = adTask.getDetectorLevelTaskId(); float initProgress = calculateInitProgress(taskId); - String taskState = initProgress >= 1.0f ? ADTaskState.RUNNING.name() : ADTaskState.INIT.name(); + String taskState = initProgress >= 1.0f ? 
TaskState.RUNNING.name() : TaskState.INIT.name(); logger.debug("Init progress: {}, taskState:{}, task id: {}", initProgress, taskState, taskId); if (initProgress >= 1.0f && adTask.isEntityTask()) { - updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), ADTaskState.RUNNING.name()); + updateDetectorLevelTaskState(detectorId, adTask.getParentTaskId(), TaskState.RUNNING.name()); } if (pieceStartTime < dataEndTime) { @@ -1271,17 +1272,17 @@ private void runNextPiece( float taskProgress = (float) (pieceStartTime - dataStartTime) / (dataEndTime - dataStartTime); logger.debug("Task progress: {}, task id:{}, detector id:{}", taskProgress, taskId, detectorId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, taskState, - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, pieceStartTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, taskProgress, - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress ), ActionListener @@ -1304,20 +1305,20 @@ private void runNextPiece( logger.info("AD task finished for detector {}, task id: {}", detectorId, taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); adTaskManager - .updateADTask( + .updateTask( taskId, ImmutableMap .of( - CURRENT_PIECE_FIELD, + TimeSeriesTask.CURRENT_PIECE_FIELD, dataEndTime, - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0f, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli(), - INIT_PROGRESS_FIELD, + TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress, - STATE_FIELD, - ADTaskState.FINISHED + TimeSeriesTask.STATE_FIELD, + TaskState.FINISHED ), ActionListener.wrap(r -> internalListener.onResponse("task execution done"), e -> internalListener.onFailure(e)) ); @@ -1326,7 +1327,7 @@ private void runNextPiece( private void updateDetectorLevelTaskState(String detectorId, String detectorTaskId, String newState) { ExecutorFunction function = () -> adTaskManager - .updateADTask(detectorTaskId, ImmutableMap.of(STATE_FIELD, newState), ActionListener.wrap(r -> { + .updateTask(detectorTaskId, ImmutableMap.of(TimeSeriesTask.STATE_FIELD, newState), ActionListener.wrap(r -> { logger.info("Updated HC detector task: {} state as: {} for detector: {}", detectorTaskId, newState, detectorId); adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); }, e -> { logger.error("Failed to update HC detector task: {} for detector: {}", detectorTaskId, detectorId); })); @@ -1350,14 +1351,14 @@ private float calculateInitProgress(String taskId) { if (rcf == null) { return 0.0f; } - float initProgress = (float) rcf.getTotalUpdates() / NUM_MIN_SAMPLES; + float initProgress = (float) rcf.getTotalUpdates() / TimeSeriesSettings.NUM_MIN_SAMPLES; logger.debug("RCF total updates {} for task {}", rcf.getTotalUpdates(), taskId); return initProgress > 1.0f ? 
1.0f : initProgress; } private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String detectorTaskId = adTask.getDetectorLevelTaskId(); // refresh latest HC task run time adTaskCacheManager.refreshLatestHCTaskRunTime(detectorId); @@ -1379,7 +1380,7 @@ private void checkIfADTaskCancelledAndCleanupCache(ADTask adTask) { String cancelledBy = adTaskCacheManager.getCancelledBy(taskId); adTaskCacheManager.remove(taskId, detectorId, detectorTaskId); if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId) - && isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { + && ParseUtils.isNullOrEmpty(adTaskCacheManager.getTasksOfDetector(detectorId))) { // Clean up historical task cache for HC detector on worker node if no running entity task. logger.info("All AD task cancelled, cleanup historical task cache for detector {}", detectorId); adTaskCacheManager.removeHistoricalTaskCache(detectorId); diff --git a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java index 91f00b4cd..7f4c70f81 100644 --- a/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java +++ b/src/main/java/org/opensearch/ad/task/ADHCBatchTaskRunState.java @@ -13,7 +13,7 @@ import java.time.Instant; -import org.opensearch.ad.model.ADTaskState; +import org.opensearch.timeseries.model.TaskState; /** * Cache HC batch task running state on coordinating and worker node. @@ -32,7 +32,7 @@ public class ADHCBatchTaskRunState { private Long cancelledTimeInMillis; public ADHCBatchTaskRunState() { - this.detectorTaskState = ADTaskState.INIT.name(); + this.detectorTaskState = TaskState.INIT.name(); } public String getDetectorTaskState() { diff --git a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java index 0df994963..ddd85229d 100644 --- a/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java +++ b/src/main/java/org/opensearch/ad/task/ADTaskCacheManager.java @@ -11,13 +11,10 @@ package org.opensearch.ad.task; -import static org.opensearch.ad.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING; import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CACHED_DELETED_TASKS; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_TREES; -import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty; +import static org.opensearch.timeseries.MemoryTracker.Origin.HISTORICAL_SINGLE_ENTITY_DETECTOR; import java.time.Instant; import java.util.ArrayList; @@ -35,29 +32,27 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.ActionListener; -import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.MemoryTracker; import 
org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.model.Entity; -import org.opensearch.transport.TransportService; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.util.ParseUtils; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; -public class ADTaskCacheManager { +public class ADTaskCacheManager extends TaskCacheManager { private final Logger logger = LogManager.getLogger(ADTaskCacheManager.class); private volatile Integer maxAdBatchTaskPerNode; - private volatile Integer maxCachedDeletedTask; private final MemoryTracker memoryTracker; private final int numberSize = 8; public static final int TASK_RETRY_LIMIT = 3; @@ -89,19 +84,6 @@ public class ADTaskCacheManager { *

<p>Key: detector id</p> */ private Map<String, ADTaskSlotLimit> detectorTaskSlotLimit; - /** - * This field is to cache all realtime tasks on coordinating node. - * <p>Node: coordinating node</p> - * <p>Key is detector id</p> - */ - private Map<String, ADRealtimeTaskCache> realtimeTaskCaches; - /** - * This field is to cache all deleted detector level tasks on coordinating node. - * Will try to clean up child task and AD result later. - * <p>Node: coordinating node</p> - * Check {@link ADTaskManager#cleanChildTasksAndADResultsOfDeletedTask()} - */ - private Queue<String> deletedDetectorTasks; // =================================================================== // Fields below are caches on worker node @@ -145,16 +127,13 @@ public class ADTaskCacheManager { * @param memoryTracker AD memory tracker */ public ADTaskCacheManager(Settings settings, ClusterService clusterService, MemoryTracker memoryTracker) { + super(settings, clusterService); this.maxAdBatchTaskPerNode = MAX_BATCH_TASK_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_BATCH_TASK_PER_NODE, it -> maxAdBatchTaskPerNode = it); - this.maxCachedDeletedTask = MAX_CACHED_DELETED_TASKS.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CACHED_DELETED_TASKS, it -> maxCachedDeletedTask = it); this.batchTaskCaches = new ConcurrentHashMap<>(); this.memoryTracker = memoryTracker; this.detectorTasks = new ConcurrentHashMap<>(); this.hcBatchTaskCaches = new ConcurrentHashMap<>(); - this.realtimeTaskCaches = new ConcurrentHashMap<>(); - this.deletedDetectorTasks = new ConcurrentLinkedQueue<>(); this.deletedDetectors = new ConcurrentLinkedQueue<>(); this.detectorTaskSlotLimit = new ConcurrentHashMap<>(); this.hcBatchTaskRunState = new ConcurrentHashMap<>(); @@ -171,7 +150,7 @@ public ADTaskCacheManager(Settings settings, ClusterService clusterService, Memo */ public synchronized void add(ADTask adTask) { String taskId = adTask.getTaskId(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); if (contains(taskId)) { throw new DuplicateTaskException(DETECTOR_IS_RUNNING); } @@ -212,7 +191,7 @@ public synchronized void add(String detectorId, ADTask adTask) { } logger.info("add detector in running detector cache, detectorId: {}, taskId: {}", detectorId, adTask.getTaskId()); this.detectorTasks.put(detectorId, adTask.getTaskId()); - if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { + if (ADTaskType.AD_HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { ADHCBatchTaskCache adhcBatchTaskCache = new ADHCBatchTaskCache(); this.hcBatchTaskCaches.put(detectorId, adhcBatchTaskCache); } @@ -354,8 +333,8 @@ private long calculateADTaskCacheSize(ADTask adTask) { return memoryTracker .estimateTRCFModelSize( dimension, - NUM_TREES, - AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, detector.getShingleSize().intValue(), false ) + shingleMemorySize(detector.getShingleSize(), detector.getEnabledFeatureIds().size()); @@ -373,8 +352,7 @@ public long getModelSize(String taskId) { RandomCutForest rcfForest = tRCF.getForest(); int dimensions = rcfForest.getDimensions(); int numberOfTrees = rcfForest.getNumberOfTrees(); - return memoryTracker .estimateTRCFModelSize(dimensions, numberOfTrees, AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, 1, false); + return memoryTracker.estimateTRCFModelSize(dimensions, numberOfTrees, TimeSeriesSettings.BATCH_BOUNDING_BOX_CACHE_RATIO, 1, false); } /** @@ -483,7 +461,7 @@ public ADTaskCancellationState cancelByDetectorId(String detectorId, String dete taskStateCache.setCancelReason(reason); taskStateCache.setCancelledBy(userName); - if (isNullOrEmpty(taskCaches)) { + if (ParseUtils.isNullOrEmpty(taskCaches)) { return ADTaskCancellationState.NOT_FOUND; } @@ -1012,156 +990,6 @@ public void clearPendingEntities(String detectorId) { } } - /** - * Check if realtime task field
value change needed or not by comparing with cache. - * 1. If new field value is null, will consider changed needed to this field. - * 2. will consider the real time task change needed if - * 1) init progress is larger or the old init progress is null, or - * 2) if the state is different, and it is not changing from running to init. - * for other fields, as long as field values changed, will consider the realtime - * task change needed. We did this so that the init progress or state won't go backwards. - * 3. If realtime task cache not found, will consider the realtime task change needed. - * - * @param detectorId detector id - * @param newState new task state - * @param newInitProgress new init progress - * @param newError new error - * @return true if realtime task change needed. - */ - public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Float newInitProgress, String newError) { - if (realtimeTaskCaches.containsKey(detectorId)) { - ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); - boolean stateChangeNeeded = false; - String oldState = realtimeTaskCache.getState(); - if (newState != null - && !newState.equals(oldState) - && !(ADTaskState.INIT.name().equals(newState) && ADTaskState.RUNNING.name().equals(oldState))) { - stateChangeNeeded = true; - } - boolean initProgressChangeNeeded = false; - Float existingProgress = realtimeTaskCache.getInitProgress(); - if (newInitProgress != null - && !newInitProgress.equals(existingProgress) - && (existingProgress == null || newInitProgress > existingProgress)) { - initProgressChangeNeeded = true; - } - boolean errorChanged = false; - if (newError != null && !newError.equals(realtimeTaskCache.getError())) { - errorChanged = true; - } - if (stateChangeNeeded || initProgressChangeNeeded || errorChanged) { - return true; - } - return false; - } else { - return true; - } - } - - /** - * Update realtime task cache with new field values. If realtime task cache exist, update it - * directly if task is not done; if task is done, remove the detector's realtime task cache. - * - * If realtime task cache doesn't exist, will do nothing. Next realtime job run will re-init - * realtime task cache when it finds task cache not inited yet. - * Check {@link ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache(String, AnomalyDetector, TransportService, ActionListener)}, - * {@link ADTaskManager#updateLatestRealtimeTaskOnCoordinatingNode(String, String, Long, Long, String, ActionListener)} - * - * @param detectorId detector id - * @param newState new task state - * @param newInitProgress new init progress - * @param newError new error - */ - public void updateRealtimeTaskCache(String detectorId, String newState, Float newInitProgress, String newError) { - ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); - if (realtimeTaskCache != null) { - if (newState != null) { - realtimeTaskCache.setState(newState); - } - if (newInitProgress != null) { - realtimeTaskCache.setInitProgress(newInitProgress); - } - if (newError != null) { - realtimeTaskCache.setError(newError); - } - if (newState != null && !ADTaskState.NOT_ENDED_STATES.contains(newState)) { - // If task is done, will remove its realtime task cache. 
- logger.info("Realtime task done with state {}, remove RT task cache for detector ", newState, detectorId); - removeRealtimeTaskCache(detectorId); - } - } else { - logger.debug("Realtime task cache is not inited yet for detector {}", detectorId); - } - } - - public void initRealtimeTaskCache(String detectorId, long detectorIntervalInMillis) { - realtimeTaskCaches.put(detectorId, new ADRealtimeTaskCache(null, null, null, detectorIntervalInMillis)); - logger.debug("Realtime task cache inited"); - } - - public void refreshRealtimeJobRunTime(String detectorId) { - ADRealtimeTaskCache taskCache = realtimeTaskCaches.get(detectorId); - if (taskCache != null) { - taskCache.setLastJobRunTime(Instant.now().toEpochMilli()); - } - } - - /** - * Get detector IDs from realtime task cache. - * @return array of detector id - */ - public String[] getDetectorIdsInRealtimeTaskCache() { - return realtimeTaskCaches.keySet().toArray(new String[0]); - } - - /** - * Remove detector's realtime task from cache. - * @param detectorId detector id - */ - public void removeRealtimeTaskCache(String detectorId) { - if (realtimeTaskCaches.containsKey(detectorId)) { - logger.info("Delete realtime cache for detector {}", detectorId); - realtimeTaskCaches.remove(detectorId); - } - } - - public ADRealtimeTaskCache getRealtimeTaskCache(String detectorId) { - return realtimeTaskCaches.get(detectorId); - } - - /** - * Clear realtime task cache. - */ - public void clearRealtimeTaskCache() { - realtimeTaskCaches.clear(); - } - - /** - * Add deleted task's id to deleted detector tasks queue. - * @param taskId task id - */ - public void addDeletedDetectorTask(String taskId) { - if (deletedDetectorTasks.size() < maxCachedDeletedTask) { - deletedDetectorTasks.add(taskId); - } - } - - /** - * Check if deleted task queue has items. - * @return true if has deleted detector task in cache - */ - public boolean hasDeletedDetectorTask() { - return !deletedDetectorTasks.isEmpty(); - } - - /** - * Poll one deleted detector task. - * @return task id - */ - public String pollDeletedDetectorTask() { - return this.deletedDetectorTasks.poll(); - } - /** * Add deleted detector's id to deleted detector queue. * @param detectorId detector id @@ -1317,7 +1145,7 @@ public void cleanExpiredHCBatchTaskRunStates() { for (Map.Entry<String, Map<String, ADHCBatchTaskRunState>> detectorRunStates : hcBatchTaskRunState.entrySet()) { List<String> taskIdOfExpiredStates = new ArrayList<>(); String detectorId = detectorRunStates.getKey(); - boolean noRunningTask = isNullOrEmpty(getTasksOfDetector(detectorId)); + boolean noRunningTask = ParseUtils.isNullOrEmpty(getTasksOfDetector(detectorId)); Map<String, ADHCBatchTaskRunState> taskRunStates = detectorRunStates.getValue(); if (taskRunStates == null) { // If detector's task run state is null, add detector id to detectorIdOfEmptyStates and remove it from @@ -1362,32 +1190,4 @@ public void cleanExpiredHCBatchTaskRunStates() { } } - /** - * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not. - * To avoid repeated query when there is no data, record whether we have done that or not. - * @param id detector id - */ - public void markResultIndexQueried(String id) { - ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); - // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it - // cannot be found. If the cache is empty, we will return early and wait it for it to be - // initialized.
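The realtime-task bookkeeping removed in this hunk (now inherited from TaskCacheManager) enforces that state and init progress never move backwards. A condensed sketch of that rule, based on the deleted isRealtimeTaskChangeNeeded above; the wrapper class and method names are illustrative:

```java
// Condensed form of the removed isRealtimeTaskChangeNeeded check:
// only forward progress should trigger an index write.
final class RealtimeTaskUpdateRule {
    static boolean changeNeeded(
        String oldState, Float oldProgress, String oldError,
        String newState, Float newProgress, String newError
    ) {
        // A state change counts, except a regression from RUNNING back to INIT.
        boolean stateChanged = newState != null
            && !newState.equals(oldState)
            && !("INIT".equals(newState) && "RUNNING".equals(oldState));
        // Init progress is only allowed to increase.
        boolean progressChanged = newProgress != null
            && !newProgress.equals(oldProgress)
            && (oldProgress == null || newProgress > oldProgress);
        boolean errorChanged = newError != null && !newError.equals(oldError);
        return stateChanged || progressChanged || errorChanged;
    }
}
```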
     /**
      * Add deleted detector's id to deleted detector queue.
      * @param detectorId detector id
@@ -1317,7 +1145,7 @@ public void cleanExpiredHCBatchTaskRunStates() {
         for (Map.Entry<String, Map<String, ADHCBatchTaskRunState>> detectorRunStates : hcBatchTaskRunState.entrySet()) {
             List<String> taskIdOfExpiredStates = new ArrayList<>();
             String detectorId = detectorRunStates.getKey();
-            boolean noRunningTask = isNullOrEmpty(getTasksOfDetector(detectorId));
+            boolean noRunningTask = ParseUtils.isNullOrEmpty(getTasksOfDetector(detectorId));
             Map<String, ADHCBatchTaskRunState> taskRunStates = detectorRunStates.getValue();
             if (taskRunStates == null) {
                 // If detector's task run state is null, add detector id to detectorIdOfEmptyStates and remove it from
@@ -1362,32 +1190,4 @@ public void cleanExpiredHCBatchTaskRunStates() {
         }
     }
 
-    /**
-     * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not.
-     * To avoid repeated query when there is no data, record whether we have done that or not.
-     * @param id detector id
-     */
-    public void markResultIndexQueried(String id) {
-        ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id);
-        // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it
-        // cannot be found. If the cache is empty, we will return early and wait it for it to be
-        // initialized.
-        if (realtimeTaskCache != null) {
-            realtimeTaskCache.setQueriedResultIndex(true);
-        }
-    }
-
-    /**
-     * We query result index to check if there are any result generated for detector to tell whether it passed initialization of not.
-     *
-     * @param id detector id
-     * @return whether we have queried result index or not.
-     */
-    public boolean hasQueriedResultIndex(String id) {
-        ADRealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id);
-        if (realtimeTaskCache != null) {
-            return realtimeTaskCache.hasQueriedResultIndex();
-        }
-        return false;
-    }
 }
diff --git a/src/main/java/org/opensearch/ad/task/ADTaskManager.java b/src/main/java/org/opensearch/ad/task/ADTaskManager.java
index c482b0ba8..9c1fe0cef 100644
--- a/src/main/java/org/opensearch/ad/task/ADTaskManager.java
+++ b/src/main/java/org/opensearch/ad/task/ADTaskManager.java
@@ -12,58 +12,34 @@
 package org.opensearch.ad.task;
 
 import static org.opensearch.action.DocWriteResponse.Result.CREATED;
-import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK;
 import static org.opensearch.ad.constant.ADCommonMessages.DETECTOR_IS_RUNNING;
 import static org.opensearch.ad.constant.ADCommonMessages.EXCEED_HISTORICAL_ANALYSIS_LIMIT;
 import static org.opensearch.ad.constant.ADCommonMessages.HC_DETECTOR_TASK_IS_UPDATING;
 import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR;
 import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX;
 import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
-import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD;
 import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD;
-import static org.opensearch.ad.model.ADTask.ERROR_FIELD;
-import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD;
-import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD;
-import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD;
-import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD;
-import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD;
-import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD;
-import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD;
-import static org.opensearch.ad.model.ADTask.STATE_FIELD;
-import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD;
-import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD;
-import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD;
-import static org.opensearch.ad.model.ADTaskState.NOT_ENDED_STATES;
 import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES;
 import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES;
 import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES;
-import static org.opensearch.ad.model.ADTaskType.taskTypeToString;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR;
 import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT;
-import static org.opensearch.ad.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT;
-import static org.opensearch.ad.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT;
-import static org.opensearch.ad.util.ExceptionUtil.getErrorMessage;
-import static org.opensearch.ad.util.ExceptionUtil.getShardsFailure;
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME;
-import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
-import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD;
-import static org.opensearch.timeseries.util.ParseUtils.isNullOrEmpty;
-import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE;
+import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES;
+import static org.opensearch.timeseries.model.TaskType.taskTypeToString;
+import static org.opensearch.timeseries.stats.InternalStatNames.AD_DETECTOR_ASSIGNED_BATCH_TASK_SLOT_COUNT;
+import static org.opensearch.timeseries.stats.InternalStatNames.AD_USED_BATCH_TASK_SLOT_COUNT;
 import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry;
 
 import java.io.IOException;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
@@ -74,44 +50,31 @@
 import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.Semaphore;
 import java.util.function.BiConsumer;
-import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.search.TotalHits;
-import org.apache.lucene.search.join.ScoreMode;
 import org.opensearch.ExceptionsHelper;
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.ResourceAlreadyExistsException;
 import org.opensearch.Version;
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.ActionListenerResponseHandler;
-import org.opensearch.action.bulk.BulkAction;
-import org.opensearch.action.bulk.BulkItemResponse;
-import org.opensearch.action.bulk.BulkRequest;
-import org.opensearch.action.delete.DeleteRequest;
 import org.opensearch.action.delete.DeleteResponse;
 import org.opensearch.action.get.GetRequest;
-import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
 import org.opensearch.action.search.SearchRequest;
-import org.opensearch.action.search.SearchResponse;
-import org.opensearch.action.support.WriteRequest;
-import org.opensearch.action.update.UpdateRequest;
 import org.opensearch.action.update.UpdateResponse;
-import org.opensearch.ad.cluster.HashRing;
+import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
 import org.opensearch.ad.model.ADEntityTaskProfile;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskAction;
 import org.opensearch.ad.model.ADTaskProfile;
-import org.opensearch.ad.model.ADTaskState;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.DetectorProfile;
-import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler;
 import org.opensearch.ad.transport.ADBatchAnomalyResultAction;
 import org.opensearch.ad.transport.ADBatchAnomalyResultRequest;
 import org.opensearch.ad.transport.ADCancelTaskAction;
@@ -122,7 +85,6 @@
 import org.opensearch.ad.transport.ADTaskProfileAction;
 import org.opensearch.ad.transport.ADTaskProfileNodeResponse;
 import org.opensearch.ad.transport.ADTaskProfileRequest;
-import org.opensearch.ad.transport.AnomalyDetectorJobResponse;
 import org.opensearch.ad.transport.ForwardADTaskAction;
 import org.opensearch.ad.transport.ForwardADTaskRequest;
 import org.opensearch.client.Client;
@@ -130,7 +92,6 @@
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
-import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.common.xcontent.XContentType;
@@ -143,31 +104,37 @@
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.index.query.BoolQueryBuilder;
-import org.opensearch.index.query.NestedQueryBuilder;
-import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.index.query.TermsQueryBuilder;
 import org.opensearch.index.reindex.DeleteByQueryAction;
 import org.opensearch.index.reindex.DeleteByQueryRequest;
-import org.opensearch.index.reindex.UpdateByQueryAction;
-import org.opensearch.index.reindex.UpdateByQueryRequest;
-import org.opensearch.script.Script;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.search.sort.SortOrder;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.NodeStateManager;
+import org.opensearch.timeseries.cluster.HashRing;
 import org.opensearch.timeseries.common.exception.DuplicateTaskException;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.LimitExceededException;
 import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
 import org.opensearch.timeseries.common.exception.TaskCancelledException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
-import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.function.BiCheckedFunction;
 import org.opensearch.timeseries.function.ExecutorFunction;
+import org.opensearch.timeseries.model.Config;
 import org.opensearch.timeseries.model.DateRange;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.TaskState;
+import org.opensearch.timeseries.model.TaskType;
+import org.opensearch.timeseries.model.TimeSeriesTask;
+import org.opensearch.timeseries.task.TaskManager;
+import org.opensearch.timeseries.transport.JobResponse;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
-import org.opensearch.timeseries.util.RestHandlerUtils;
+import org.opensearch.timeseries.util.ExceptionUtil;
+import org.opensearch.timeseries.util.ParseUtils;
 import org.opensearch.transport.TransportRequestOptions;
 import org.opensearch.transport.TransportService;
 
@@ -178,28 +145,21 @@
 /**
  * Manage AD task.
  */
-public class ADTaskManager {
+public class ADTaskManager extends TaskManager<ADTaskCacheManager, ADTaskType, ADTask, ADIndex, ADIndexManagement> {
     public static final String AD_TASK_LEAD_NODE_MODEL_ID = "ad_task_lead_node_model_id";
     public static final String AD_TASK_MAINTAINENCE_NODE_MODEL_ID = "ad_task_maintainence_node_model_id";
     // HC batch task timeout after 10 minutes if no update after last known run time.
     public static final int HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS = 600_000;
-    private final Logger logger = LogManager.getLogger(this.getClass());
+    public final Logger logger = LogManager.getLogger(this.getClass());
     static final String STATE_INDEX_NOT_EXIST_MSG = "State index does not exist.";
     private final Set<String> retryableErrors = ImmutableSet.of(EXCEED_HISTORICAL_ANALYSIS_LIMIT, NO_ELIGIBLE_NODE_TO_RUN_DETECTOR);
-    private final Client client;
-    private final ClusterService clusterService;
-    private final NamedXContentRegistry xContentRegistry;
-    private final ADIndexManagement detectionIndices;
+
     private final DiscoveryNodeFilterer nodeFilter;
-    private final ADTaskCacheManager adTaskCacheManager;
     private final HashRing hashRing;
-    private volatile Integer maxOldAdTaskDocsPerDetector;
     private volatile Integer pieceIntervalSeconds;
     private volatile boolean deleteADResultWhenDeleteDetector;
     private volatile TransportRequestOptions transportRequestOptions;
-    private final ThreadPool threadPool;
-    private static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5;
     private final Semaphore checkingTaskSlot;
     private volatile Integer maxAdBatchTaskPerNode;
@@ -217,21 +177,30 @@ public ADTaskManager(
         DiscoveryNodeFilterer nodeFilter,
         HashRing hashRing,
         ADTaskCacheManager adTaskCacheManager,
-        ThreadPool threadPool
+        ThreadPool threadPool,
+        NodeStateManager nodeStateManager
     ) {
-        this.client = client;
-        this.xContentRegistry = xContentRegistry;
-        this.detectionIndices = detectionIndices;
+        super(
+            adTaskCacheManager,
+            clusterService,
+            client,
+            DETECTION_STATE_INDEX,
+            ADTaskType.REALTIME_TASK_TYPES,
+            detectionIndices,
+            nodeStateManager,
+            AnalysisType.AD,
+            xContentRegistry,
+            DETECTOR_ID_FIELD,
+            MAX_OLD_AD_TASK_DOCS_PER_DETECTOR,
+            settings,
+            threadPool,
+            ALL_AD_RESULTS_INDEX_PATTERN,
+            AD_BATCH_TASK_THREAD_POOL_NAME
+        );
+
         this.nodeFilter = nodeFilter;
-        this.clusterService = clusterService;
-        this.adTaskCacheManager = adTaskCacheManager;
         this.hashRing = hashRing;
-        this.maxOldAdTaskDocsPerDetector = MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.get(settings);
-        clusterService
-            .getClusterSettings()
-            .addSettingsUpdateConsumer(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, it -> maxOldAdTaskDocsPerDetector = it);
-
         this.pieceIntervalSeconds = BATCH_TASK_PIECE_INTERVAL_SECONDS.get(settings);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(BATCH_TASK_PIECE_INTERVAL_SECONDS, it -> pieceIntervalSeconds = it);
@@ -251,12 +220,12 @@ public ADTaskManager(
         transportRequestOptions = TransportRequestOptions
             .builder()
             .withType(TransportRequestOptions.Type.REG)
-            .withTimeout(REQUEST_TIMEOUT.get(settings))
+            .withTimeout(AD_REQUEST_TIMEOUT.get(settings))
             .build();
         clusterService
             .getClusterSettings()
             .addSettingsUpdateConsumer(
-                REQUEST_TIMEOUT,
+                AD_REQUEST_TIMEOUT,
                 it -> {
                     transportRequestOptions = TransportRequestOptions
                         .builder()
@@ -265,85 +234,11 @@ public ADTaskManager(
                        .build();
                 }
             );
-        this.threadPool = threadPool;
+
         this.checkingTaskSlot = new Semaphore(1);
         this.scaleEntityTaskLane = new Semaphore(1);
     }
 
-    /**
-     * Start detector. Will create schedule job for realtime detector,
-     * and start AD task for historical detector.
-     *
-     * @param detectorId detector id
-     * @param detectionDateRange historical analysis date range
-     * @param handler anomaly detector job action handler
-     * @param user user
-     * @param transportService transport service
-     * @param context thread context
-     * @param listener action listener
-     */
-    public void startDetector(
-        String detectorId,
-        DateRange detectionDateRange,
-        IndexAnomalyDetectorJobActionHandler handler,
-        User user,
-        TransportService transportService,
-        ThreadContext.StoredContext context,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        // upgrade index mapping of AD default indices
-        detectionIndices.update();
-
-        getDetector(detectorId, (detector) -> {
-            if (!detector.isPresent()) {
-                listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND));
-                return;
-            }
-
-            // Validate if detector is ready to start. Will return null if ready to start.
-            String errorMessage = validateDetector(detector.get());
-            if (errorMessage != null) {
-                listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST));
-                return;
-            }
-            String resultIndex = detector.get().getCustomResultIndex();
-            if (resultIndex == null) {
-                startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector);
-                return;
-            }
-            context.restore();
-            detectionIndices
-                .initCustomResultIndexAndExecute(
-                    resultIndex,
-                    () -> startRealtimeOrHistoricalDetection(detectionDateRange, handler, user, transportService, listener, detector),
-                    listener
-                );
-
-        }, listener);
-    }
-
-    private void startRealtimeOrHistoricalDetection(
-        DateRange detectionDateRange,
-        IndexAnomalyDetectorJobActionHandler handler,
-        User user,
-        TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener,
-        Optional<AnomalyDetector> detector
-    ) {
-        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
-            if (detectionDateRange == null) {
-                // start realtime job
-                handler.startAnomalyDetectorJob(detector.get(), listener);
-            } else {
-                // start historical analysis task
-                forwardApplyForTaskSlotsRequestToLeadNode(detector.get(), detectionDateRange, user, transportService, listener);
-            }
-        } catch (Exception e) {
-            logger.error("Failed to stash context", e);
-            listener.onFailure(e);
-        }
-    }
-
     /**
      * When AD receives start historical analysis request for a detector, will
      * 1. Forward to lead node to check available task slots first.
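The deleted startRealtimeOrHistoricalDetection relied on a standard OpenSearch idiom: stash the thread context in a try-with-resources block so the caller's context is always restored, even if the wrapped work throws. A minimal sketch of the idiom under that assumption (the helper class and runnable are placeholders, not plugin API):

import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;

// Illustration of the stash-context idiom used by the removed method above.
// runWithStashedContext and the runnable are placeholders, not plugin API.
public final class ContextStashExample {
    private ContextStashExample() {}

    public static void runWithStashedContext(Client client, Runnable action) {
        // stashContext() clears the current thread's request context; the
        // try-with-resources guarantees the caller's original context is
        // restored even if the action throws.
        try (ThreadContext.StoredContext ignored = client.threadPool().getThreadContext().stashContext()) {
            action.run();
        }
        // On close, the stored (original) context is reinstated here.
    }
}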
@@ -358,15 +253,16 @@ private void startRealtimeOrHistoricalDetection(
      * @param transportService transport service
      * @param listener action listener
      */
-    protected void forwardApplyForTaskSlotsRequestToLeadNode(
-        AnomalyDetector detector,
+    @Override
+    public void startHistorical(
+        Config config,
         DateRange detectionDateRange,
         User user,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest(
-            detector,
+            (AnomalyDetector) config,
             detectionDateRange,
             user,
             ADTaskAction.APPLY_FOR_TASK_SLOTS
@@ -377,7 +273,7 @@ protected void forwardApplyForTaskSlotsRequestToLeadNode(
     public void forwardScaleTaskSlotRequestToLeadNode(
         ADTask adTask,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         forwardRequestToLeadNode(new ForwardADTaskRequest(adTask, ADTaskAction.CHECK_AVAILABLE_TASK_SLOTS), transportService, listener);
     }
@@ -385,9 +281,9 @@ public void forwardScaleTaskSlotRequestToLeadNode(
     public void forwardRequestToLeadNode(
         ForwardADTaskRequest forwardADTaskRequest,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
-        hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> {
+        hashRing.buildAndGetOwningNodeWithSameLocalVersion(AD_TASK_LEAD_NODE_MODEL_ID, node -> {
             if (!node.isPresent()) {
                 listener.onFailure(new ResourceNotFoundException("Can't find AD task lead node"));
                 return;
@@ -398,7 +294,7 @@ public void forwardRequestToLeadNode(
                     ForwardADTaskAction.NAME,
                     forwardADTaskRequest,
                     transportRequestOptions,
-                    new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new)
+                    new ActionListenerResponseHandler<>(listener, JobResponse::new)
                 );
         }, listener);
     }
@@ -419,10 +315,10 @@ public void startHistoricalAnalysis(
         User user,
         int availableTaskSlots,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         String detectorId = detector.getId();
-        hashRing.buildAndGetOwningNodeWithSameLocalAdVersion(detectorId, owningNode -> {
+        hashRing.buildAndGetOwningNodeWithSameLocalVersion(detectorId, owningNode -> {
             if (!owningNode.isPresent()) {
                 logger.debug("Can't find eligible node to run as AD task's coordinating node");
                 listener.onFailure(new OpenSearchStatusException("No eligible node to run detector", RestStatus.INTERNAL_SERVER_ERROR));
@@ -473,9 +369,9 @@ protected void forwardDetectRequestToCoordinatingNode(
         ADTaskAction adTaskAction,
         TransportService transportService,
         DiscoveryNode node,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
-        Version adVersion = hashRing.getAdVersion(node.getId());
+        Version adVersion = hashRing.getVersion(node.getId());
         transportService
             .sendRequest(
                 node,
@@ -484,7 +380,7 @@ protected void forwardDetectRequestToCoordinatingNode(
                 // node, check ADTaskManager#cleanDetectorCache.
                 new ForwardADTaskRequest(detector, detectionDateRange, user, adTaskAction, availableTaskSlots, adVersion),
                 transportRequestOptions,
-                new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new)
+                new ActionListenerResponseHandler<>(listener, JobResponse::new)
             );
     }
@@ -500,7 +396,7 @@ protected void forwardADTaskToCoordinatingNode(
         ADTask adTask,
         ADTaskAction adTaskAction,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         logger.debug("Forward AD task to coordinating node, task id: {}, action: {}", adTask.getTaskId(), adTaskAction.name());
         transportService
@@ -509,7 +405,7 @@ protected void forwardADTaskToCoordinatingNode(
                 ForwardADTaskAction.NAME,
                 new ForwardADTaskRequest(adTask, adTaskAction),
                 transportRequestOptions,
-                new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new)
+                new ActionListenerResponseHandler<>(listener, JobResponse::new)
             );
     }
@@ -527,7 +423,7 @@ protected void forwardStaleRunningEntitiesToCoordinatingNode(
         ADTaskAction adTaskAction,
         TransportService transportService,
         List<String> staleRunningEntity,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         transportService
             .sendRequest(
@@ -535,7 +431,7 @@ protected void forwardStaleRunningEntitiesToCoordinatingNode(
                 ForwardADTaskAction.NAME,
                 new ForwardADTaskRequest(adTask, adTaskAction, staleRunningEntity),
                 transportRequestOptions,
-                new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new)
+                new ActionListenerResponseHandler<>(listener, JobResponse::new)
             );
     }
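Every forward* method in the hunks above follows one round-trip shape: send a request to a target node with transportService.sendRequest and adapt the reply onto an ActionListener via ActionListenerResponseHandler, passing the response's Writeable reader (here JobResponse::new) as the deserializer. A hedged sketch of that shape, with a placeholder action name and request:

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.timeseries.transport.JobResponse;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

// Illustration of the request/response shape shared by the forward* methods.
// ACTION_NAME and the request argument are placeholders, not real plugin actions.
public final class ForwardShape {
    private ForwardShape() {}

    static final String ACTION_NAME = "cluster:admin/example/forward"; // hypothetical

    public static void forward(
        TransportService transportService,
        DiscoveryNode node,
        TransportRequest request,
        TransportRequestOptions options,
        ActionListener<JobResponse> listener
    ) {
        transportService
            .sendRequest(
                node,
                ACTION_NAME,
                request,
                options,
                // JobResponse::new is the Writeable.Reader used to deserialize
                // the remote node's reply into a JobResponse; both success and
                // failure are funneled into the listener.
                new ActionListenerResponseHandler<>(listener, JobResponse::new)
            );
    }
}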
@@ -559,7 +455,7 @@ public void checkTaskSlots(
         User user,
         ADTaskAction afterCheckAction,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         String detectorId = detector.getId();
         logger.debug("Start checking task slots for detector: {}, task action: {}", detectorId, afterCheckAction);
@@ -574,11 +470,11 @@ public void checkTaskSlots(
             );
             return;
         }
-        ActionListener<AnomalyDetectorJobResponse> wrappedActionListener = ActionListener.runAfter(listener, () -> {
+        ActionListener<JobResponse> wrappedActionListener = ActionListener.runAfter(listener, () -> {
             checkingTaskSlot.release(1);
             logger.debug("Release checking task slot semaphore on lead node for detector {}", detectorId);
         });
-        hashRing.getNodesWithSameLocalAdVersion(nodes -> {
+        hashRing.getNodesWithSameLocalVersion(nodes -> {
             int maxAdTaskSlots = nodes.length * maxAdBatchTaskPerNode;
             ADStatsRequest adStatsRequest = new ADStatsRequest(nodes);
             adStatsRequest
@@ -656,7 +552,7 @@ private void forwardToCoordinatingNode(
         User user,
         ADTaskAction targetActionOfTaskSlotChecking,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> wrappedActionListener,
+        ActionListener<JobResponse> wrappedActionListener,
         int approvedTaskSlots
     ) {
         switch (targetActionOfTaskSlotChecking) {
@@ -669,7 +565,7 @@ private void forwardToCoordinatingNode(
                     .info(
                         "There are {} task slots available now to scale historical analysis task lane for detector {}",
                         approvedTaskSlots,
-                        adTask.getId()
+                        adTask.getConfigId()
                     );
                 scaleTaskLaneOnCoordinatingNode(adTask, approvedTaskSlots, transportService, wrappedActionListener);
                 break;
@@ -683,7 +579,7 @@ protected void scaleTaskLaneOnCoordinatingNode(
         ADTask adTask,
         int approvedTaskSlot,
         TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
+        ActionListener<JobResponse> listener
     ) {
         DiscoveryNode coordinatingNode = getCoordinatingNode(adTask);
         transportService
@@ -692,7 +588,7 @@ protected void scaleTaskLaneOnCoordinatingNode(
                 ForwardADTaskAction.NAME,
                 new ForwardADTaskRequest(adTask, approvedTaskSlot, ADTaskAction.SCALE_ENTITY_TASK_SLOTS),
                 transportRequestOptions,
-                new ActionListenerResponseHandler<>(listener, AnomalyDetectorJobResponse::new)
+                new ActionListenerResponseHandler<>(listener, JobResponse::new)
             );
     }
@@ -707,423 +603,38 @@ private DiscoveryNode getCoordinatingNode(ADTask adTask) {
             }
         }
         if (targetNode == null) {
-            throw new ResourceNotFoundException(adTask.getId(), "AD task coordinating node not found");
+            throw new ResourceNotFoundException(adTask.getConfigId(), "AD task coordinating node not found");
         }
         return targetNode;
     }
 
-    /**
-     * Start anomaly detector.
-     * For historical analysis, this method will be called on coordinating node.
-     * For realtime task, we won't know AD job coordinating node until AD job starts. So
-     * this method will be called on vanilla node.
-     *
-     * Will init task index if not exist and write new AD task to index. If task index
-     * exists, will check if there is task running. If no running task, reset old task
-     * as not latest and clean old tasks which exceeds max old task doc limitation.
-     * Then find out node with least load and dispatch task to that node(worker node).
-     *
-     * @param detector anomaly detector
-     * @param detectionDateRange detection date range
-     * @param user user
-     * @param transportService transport service
-     * @param listener action listener
-     */
-    public void startDetector(
-        AnomalyDetector detector,
-        DateRange detectionDateRange,
-        User user,
-        TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        try {
-            if (detectionIndices.doesStateIndexExist()) {
-                // If detection index exist, check if latest AD task is running
-                getAndExecuteOnLatestDetectorLevelTask(detector.getId(), getADTaskTypes(detectionDateRange), (adTask) -> {
-                    if (!adTask.isPresent() || adTask.get().isDone()) {
-                        updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener);
-                    } else {
-                        listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST));
-                    }
-                }, transportService, true, listener);
-            } else {
-                // If detection index doesn't exist, create index and execute detector.
-                detectionIndices.initStateIndex(ActionListener.wrap(r -> {
-                    if (r.isAcknowledged()) {
-                        logger.info("Created {} with mappings.", DETECTION_STATE_INDEX);
-                        updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener);
-                    } else {
-                        String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX);
-                        logger.warn(error);
-                        listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR));
-                    }
-                }, e -> {
-                    if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) {
-                        updateLatestFlagOfOldTasksAndCreateNewTask(detector, detectionDateRange, user, listener);
-                    } else {
-                        logger.error("Failed to init anomaly detection state index", e);
-                        listener.onFailure(e);
-                    }
-                }));
-            }
-        } catch (Exception e) {
-            logger.error("Failed to start detector " + detector.getId(), e);
-            listener.onFailure(e);
-        }
-    }
-
-    private ADTaskType getADTaskType(AnomalyDetector detector, DateRange detectionDateRange) {
-        if (detectionDateRange == null) {
-            return detector.isHighCardinality() ? ADTaskType.REALTIME_HC_DETECTOR : ADTaskType.REALTIME_SINGLE_ENTITY;
-        } else {
-            return detector.isHighCardinality() ? ADTaskType.HISTORICAL_HC_DETECTOR : ADTaskType.HISTORICAL_SINGLE_ENTITY;
-        }
-    }
-
-    private List<ADTaskType> getADTaskTypes(DateRange detectionDateRange) {
-        return getADTaskTypes(detectionDateRange, false);
-    }
-
-    /**
-     * Get list of task types.
-     * 1. If detection date range is null, will return all realtime task types
-     * 2. If detection date range is not null, will return all historical detector level tasks types
-     *    if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include
-     *    HC entity level task type.
-     * @param detectionDateRange detection date range
-     * @param resetLatestTaskStateFlag reset latest task state or not
-     * @return list of AD task types
-     */
-    private List<ADTaskType> getADTaskTypes(DateRange detectionDateRange, boolean resetLatestTaskStateFlag) {
+    @Override
+    protected TaskType getTaskType(Config config, DateRange detectionDateRange) {
         if (detectionDateRange == null) {
-            return REALTIME_TASK_TYPES;
+            return config.isHighCardinality() ? ADTaskType.AD_REALTIME_HC_DETECTOR : ADTaskType.AD_REALTIME_SINGLE_STREAM;
         } else {
-            if (resetLatestTaskStateFlag) {
-                // return all task types include HC entity task to make sure we can reset all tasks latest flag
-                return ALL_HISTORICAL_TASK_TYPES;
-            } else {
-                return HISTORICAL_DETECTOR_TASK_TYPES;
-            }
-        }
-    }
-
-    /**
-     * Stop detector.
-     * For realtime detector, will set detector job as disabled.
-     * For historical detector, will set its AD task as cancelled.
-     *
-     * @param detectorId detector id
-     * @param historical stop historical analysis or not
-     * @param handler AD job action handler
-     * @param user user
-     * @param transportService transport service
-     * @param listener action listener
-     */
-    public void stopDetector(
-        String detectorId,
-        boolean historical,
-        IndexAnomalyDetectorJobActionHandler handler,
-        User user,
-        TransportService transportService,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
-        getDetector(detectorId, (detector) -> {
-            if (!detector.isPresent()) {
-                listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND));
-                return;
-            }
-            if (historical) {
-                // stop historical analyis
-                getAndExecuteOnLatestDetectorLevelTask(
-                    detectorId,
-                    HISTORICAL_DETECTOR_TASK_TYPES,
-                    (task) -> stopHistoricalAnalysis(detectorId, task, user, listener),
-                    transportService,
-                    false,// don't need to reset task state when stop detector
-                    listener
-                );
-            } else {
-                // stop realtime detector job
-                handler.stopAnomalyDetectorJob(detectorId, listener);
-            }
-        }, listener);
-    }
-
-    /**
-     * Get anomaly detector and execute consumer function.
-     * [Important!] Make sure listener returns in function
-     *
-     * @param detectorId detector id
-     * @param function consumer function
-     * @param listener action listener
-     * @param <T> action listener response type
-     */
-    public <T> void getDetector(String detectorId, Consumer<Optional<AnomalyDetector>> function, ActionListener<T> listener) {
-        GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX, detectorId);
-        client.get(getRequest, ActionListener.wrap(response -> {
-            if (!response.isExists()) {
-                function.accept(Optional.empty());
-                return;
-            }
-            try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) {
-                ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId(), response.getVersion());
-
-                function.accept(Optional.of(detector));
-            } catch (Exception e) {
-                String message = "Failed to parse anomaly detector " + detectorId;
-                logger.error(message, e);
-                listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-            }
-        }, exception -> {
-            logger.error("Failed to get detector " + detectorId, exception);
-            listener.onFailure(exception);
-        }));
-    }
-
-    /**
-     * Get latest AD task and execute consumer function.
-     * [Important!] Make sure listener returns in function
-     *
-     * @param detectorId detector id
-     * @param adTaskTypes AD task types
-     * @param function consumer function
-     * @param transportService transport service
-     * @param resetTaskState reset task state or not
-     * @param listener action listener
-     * @param <T> action listener response type
-     */
-    public <T> void getAndExecuteOnLatestDetectorLevelTask(
-        String detectorId,
-        List<ADTaskType> adTaskTypes,
-        Consumer<Optional<ADTask>> function,
-        TransportService transportService,
-        boolean resetTaskState,
-        ActionListener<T> listener
-    ) {
-        getAndExecuteOnLatestADTask(detectorId, null, null, adTaskTypes, function, transportService, resetTaskState, listener);
-    }
-
-    /**
-     * Get one latest AD task and execute consumer function.
-     * [Important!] Make sure listener returns in function
-     *
-     * @param detectorId detector id
-     * @param parentTaskId parent task id
-     * @param entity entity value
-     * @param adTaskTypes AD task types
-     * @param function consumer function
-     * @param transportService transport service
-     * @param resetTaskState reset task state or not
-     * @param listener action listener
-     * @param <T> action listener response type
-     */
-    public <T> void getAndExecuteOnLatestADTask(
-        String detectorId,
-        String parentTaskId,
-        Entity entity,
-        List<ADTaskType> adTaskTypes,
-        Consumer<Optional<ADTask>> function,
-        TransportService transportService,
-        boolean resetTaskState,
-        ActionListener<T> listener
-    ) {
-        getAndExecuteOnLatestADTasks(detectorId, parentTaskId, entity, adTaskTypes, (taskList) -> {
-            if (taskList != null && taskList.size() > 0) {
-                function.accept(Optional.ofNullable(taskList.get(0)));
-            } else {
-                function.accept(Optional.empty());
-            }
-        }, transportService, resetTaskState, 1, listener);
-    }
-
-    /**
-     * Get latest AD tasks and execute consumer function.
-     * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data
-     * node running the latest task, will reset the task state as STOPPED; otherwise, check if there is
-     * any stale running entities(entity exists in coordinating node cache but no task running on worker
-     * node) and clean up.
-     * [Important!] Make sure listener returns in function
-     *
-     * @param detectorId detector id
-     * @param parentTaskId parent task id
-     * @param entity entity value
-     * @param adTaskTypes AD task types
-     * @param function consumer function
-     * @param transportService transport service
-     * @param resetTaskState reset task state or not
-     * @param size return how many AD tasks
-     * @param listener action listener
-     * @param <T> response type of action listener
-     */
-    public <T> void getAndExecuteOnLatestADTasks(
-        String detectorId,
-        String parentTaskId,
-        Entity entity,
-        List<ADTaskType> adTaskTypes,
-        Consumer<List<ADTask>> function,
-        TransportService transportService,
-        boolean resetTaskState,
-        int size,
-        ActionListener<T> listener
-    ) {
-        BoolQueryBuilder query = new BoolQueryBuilder();
-        query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId));
-        query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true));
-        if (parentTaskId != null) {
-            query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId));
+            return config.isHighCardinality() ? ADTaskType.AD_HISTORICAL_HC_DETECTOR : ADTaskType.AD_HISTORICAL_SINGLE_STREAM;
         }
-        if (adTaskTypes != null && adTaskTypes.size() > 0) {
-            query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(adTaskTypes)));
-        }
-        if (entity != null && !isNullOrEmpty(entity.getAttributes())) {
-            String path = "entity";
-            String entityKeyFieldName = path + ".name";
-            String entityValueFieldName = path + ".value";
-
-            for (Map.Entry<String, String> attribute : entity.getAttributes().entrySet()) {
-                BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder();
-                TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey());
-                TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue());
-
-                entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery);
-                NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None);
-                query.filter(nestedQueryBuilder);
-            }
-        }
-        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
-        sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size);
-        SearchRequest searchRequest = new SearchRequest();
-        searchRequest.source(sourceBuilder);
-        searchRequest.indices(DETECTION_STATE_INDEX);
-
-        client.search(searchRequest, ActionListener.wrap(r -> {
-            // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132
-            // getTotalHits will be null when we track_total_hits is false in the query request.
-            // Add more checking here to cover some unknown cases.
-            List<ADTask> adTasks = new ArrayList<>();
-            if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
-                // don't throw exception here as consumer functions need to handle missing task
-                // in different way.
-                function.accept(adTasks);
-                return;
-            }
-
-            Iterator<SearchHit> iterator = r.getHits().iterator();
-            while (iterator.hasNext()) {
-                SearchHit searchHit = iterator.next();
-                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) {
-                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                    ADTask adTask = ADTask.parse(parser, searchHit.getId());
-                    adTasks.add(adTask);
-                } catch (Exception e) {
-                    String message = "Failed to parse AD task for detector " + detectorId + ", task id " + searchHit.getId();
-                    logger.error(message, e);
-                    listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR));
-                }
-            }
-            if (resetTaskState) {
-                resetLatestDetectorTaskState(adTasks, function, transportService, listener);
-            } else {
-                function.accept(adTasks);
-            }
-        }, e -> {
-            if (e instanceof IndexNotFoundException) {
-                function.accept(new ArrayList<>());
-            } else {
-                logger.error("Failed to search AD task for detector " + detectorId, e);
-                listener.onFailure(e);
-            }
-        }));
     }
 
     /**
-     * Reset latest detector task state. Will reset both historical and realtime tasks.
-     * [Important!] Make sure listener returns in function
-     *
-     * @param adTasks ad tasks
-     * @param function consumer function
-     * @param transportService transport service
-     * @param listener action listener
-     * @param <T> response type of action listener
-     */
-    private <T> void resetLatestDetectorTaskState(
-        List<ADTask> adTasks,
-        Consumer<List<ADTask>> function,
-        TransportService transportService,
-        ActionListener<T> listener
-    ) {
-        List<ADTask> runningHistoricalTasks = new ArrayList<>();
-        List<ADTask> runningRealtimeTasks = new ArrayList<>();
-        for (ADTask adTask : adTasks) {
-            if (!adTask.isEntityTask() && !adTask.isDone()) {
-                if (!adTask.isHistoricalTask()) {
-                    // try to reset task state if realtime task is not ended
-                    runningRealtimeTasks.add(adTask);
-                } else {
-                    // try to reset task state if historical task not updated for 2 piece intervals
-                    runningHistoricalTasks.add(adTask);
-                }
-            }
-        }
-
-        resetHistoricalDetectorTaskState(
-            runningHistoricalTasks,
-            () -> resetRealtimeDetectorTaskState(runningRealtimeTasks, () -> function.accept(adTasks), transportService, listener),
-            transportService,
-            listener
-        );
-    }
-
-    private <T> void resetRealtimeDetectorTaskState(
-        List<ADTask> runningRealtimeTasks,
+     * If resetTaskState is true, will collect latest task's profile data from all data nodes. If no data
+     * node is running the latest task, will reset the task state as STOPPED; otherwise, check if there is
+     * any stale running entities (entity exists in coordinating node cache but no task running on worker
+     * node) and clean up.
+     */
+    @Override
+    protected <T> void resetHistoricalConfigTaskState(
+        List<TimeSeriesTask> runningHistoricalTasks,
         ExecutorFunction function,
         TransportService transportService,
         ActionListener<T> listener
     ) {
-        if (isNullOrEmpty(runningRealtimeTasks)) {
+        if (ParseUtils.isNullOrEmpty(runningHistoricalTasks)) {
             function.execute();
             return;
         }
-        ADTask adTask = runningRealtimeTasks.get(0);
-        String detectorId = adTask.getId();
-        GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId);
-        client.get(getJobRequest, ActionListener.wrap(r -> {
-            if (r.isExists()) {
-                try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
-                    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
-                    AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
-                    if (!job.isEnabled()) {
-                        logger.debug("AD job is disabled, reset realtime task as stopped for detector {}", detectorId);
-                        resetTaskStateAsStopped(adTask, function, transportService, listener);
-                    } else {
-                        function.execute();
-                    }
-                } catch (IOException e) {
-                    logger.error(" Failed to parse AD job " + detectorId, e);
-                    listener.onFailure(e);
-                }
-            } else {
-                logger.debug("AD job is not found, reset realtime task as stopped for detector {}", detectorId);
-                resetTaskStateAsStopped(adTask, function, transportService, listener);
-            }
-        }, e -> {
-            logger.error("Fail to get AD realtime job for detector " + detectorId, e);
-            listener.onFailure(e);
-        }));
-    }
-
-    private <T> void resetHistoricalDetectorTaskState(
-        List<ADTask> runningHistoricalTasks,
-        ExecutorFunction function,
-        TransportService transportService,
-        ActionListener<T> listener
-    ) {
-        if (isNullOrEmpty(runningHistoricalTasks)) {
-            function.execute();
-            return;
-        }
-        ADTask adTask = runningHistoricalTasks.get(0);
+        ADTask adTask = (ADTask) runningHistoricalTasks.get(0);
         // If AD task is still running, but its last updated time not refreshed for 2 piece intervals, we will get
         // task profile to check if it's really running. If task not running, reset state as STOPPED.
         // For example, ES process crashes, then all tasks running on it will stay as running. We can reset the task
@@ -1140,10 +651,10 @@ private void resetHistoricalDetectorTaskState(
                 logger.debug("Reset task state as stopped, task id: {}", adTask.getTaskId());
                 if (taskProfile.getTaskId() == null // This means coordinating node doesn't have HC detector cache
                     && detector.isHighCardinality()
-                    && !isNullOrEmpty(taskProfile.getEntityTaskProfiles())) {
+                    && !ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles())) {
                     // If coordinating node restarted, HC detector cache on it will be gone. But worker node still
                     // runs entity tasks, we'd better stop these entity tasks to clean up resource earlier.
-                    stopHistoricalAnalysis(adTask.getId(), Optional.of(adTask), null, ActionListener.wrap(r -> {
+                    stopHistoricalAnalysis(adTask.getConfigId(), Optional.of(adTask), null, ActionListener.wrap(r -> {
                         logger.debug("Restop detector successfully");
                         resetTaskStateAsStopped(adTask, function, transportService, listener);
                     }, e -> {
@@ -1156,10 +667,11 @@ private void resetHistoricalDetectorTaskState(
             } else {
                 function.execute();
                 // If still running, check if there is any stale running entities and clean them
-                if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) {
+                if (ADTaskType.AD_HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) {
                     // Check if any running entity not run on worker node. If yes, we need to remove it
                     // and poll next entity from pending entity queue and run it.
-                    if (!isNullOrEmpty(taskProfile.getRunningEntities()) && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) {
+                    if (!ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities())
+                        && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) {
                         List<String> runningTasksInCoordinatingNodeCache = new ArrayList<>(taskProfile.getRunningEntities());
                         List<String> runningTasksOnWorkerNode = new ArrayList<>();
                         if (taskProfile.getEntityTaskProfiles() != null && taskProfile.getEntityTaskProfiles().size() > 0) {
@@ -1204,8 +716,8 @@ private boolean isTaskStopped(String taskId, AnomalyDetector detector, ADTaskPro
         }
         if (detector.isHighCardinality()
             && taskProfile.getTotalEntitiesInited()
-            && isNullOrEmpty(taskProfile.getRunningEntities())
-            && isNullOrEmpty(taskProfile.getEntityTaskProfiles())
+            && ParseUtils.isNullOrEmpty(taskProfile.getRunningEntities())
+            && ParseUtils.isNullOrEmpty(taskProfile.getEntityTaskProfiles())
             && hcBatchTaskExpired(taskProfile.getLatestHCTaskRunTime())) {
             logger.debug("AD task not running for HC detector {}, task {}", detectorId, taskId);
             return true;
@@ -1220,12 +732,8 @@ public boolean hcBatchTaskExpired(Long latestHCTaskRunTime) {
         return latestHCTaskRunTime + HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS < Instant.now().toEpochMilli();
     }
 
-    private void stopHistoricalAnalysis(
-        String detectorId,
-        Optional<ADTask> adTask,
-        User user,
-        ActionListener<AnomalyDetectorJobResponse> listener
-    ) {
+    @Override
+    public void stopHistoricalAnalysis(String detectorId, Optional<ADTask> adTask, User user, ActionListener<JobResponse> listener) {
         if (!adTask.isPresent()) {
             listener.onFailure(new ResourceNotFoundException(detectorId, "Detector not started"));
             return;
@@ -1237,7 +745,7 @@ private void stopHistoricalAnalysis(
         }
 
         String taskId = adTask.get().getTaskId();
-        DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalAdVersion();
+        DiscoveryNode[] dataNodes = hashRing.getNodesWithSameLocalVersion();
         String userName = user == null ? null : user.getName();
 
         ADCancelTaskRequest cancelTaskRequest = new ADCancelTaskRequest(detectorId, taskId, userName, dataNodes);
@@ -1245,66 +753,22 @@ private void stopHistoricalAnalysis(
             .execute(
                 ADCancelTaskAction.INSTANCE,
                 cancelTaskRequest,
-                ActionListener
-                    .wrap(response -> { listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); }, e -> {
-                        logger.error("Failed to cancel AD task " + taskId + ", detector id: " + detectorId, e);
-                        listener.onFailure(e);
-                    })
+                ActionListener.wrap(response -> { listener.onResponse(new JobResponse(taskId)); }, e -> {
+                    logger.error("Failed to cancel AD task " + taskId + ", detector id: " + detectorId, e);
+                    listener.onFailure(e);
+                })
             );
     }
 
-    private boolean lastUpdateTimeOfHistoricalTaskExpired(ADTask adTask) {
+    private boolean lastUpdateTimeOfHistoricalTaskExpired(TimeSeriesTask adTask) {
         // Wait at least 10 seconds. Piece interval seconds is dynamic setting, user could change it to a smaller value.
         int waitingTime = Math.max(2 * pieceIntervalSeconds, 10);
         return adTask.getLastUpdateTime().plus(waitingTime, ChronoUnit.SECONDS).isBefore(Instant.now());
     }
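The two expiry checks above reduce to plain clock arithmetic: an HC batch task is treated as dead once latestHCTaskRunTime plus the 10-minute HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS falls behind the current time, and lastUpdateTimeOfHistoricalTaskExpired waits max(2 * pieceIntervalSeconds, 10) seconds, so for example with a 5-second piece interval the 10-second floor applies. A small self-contained check of the same arithmetic (the class is hypothetical; the constant mirrors this file):

import java.time.Instant;
import java.time.temporal.ChronoUnit;

// Standalone illustration of the two expiry rules above; the timeout constant
// mirrors the value in this file but the class itself is hypothetical.
public final class ExpiryRules {
    static final int HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS = 600_000; // 10 minutes

    // True when the last known HC task run time is more than 10 minutes old.
    // Treating a missing run time as expired is an illustrative choice here.
    static boolean hcBatchTaskExpired(Long latestHCTaskRunTime, long nowMillis) {
        if (latestHCTaskRunTime == null) {
            return true;
        }
        return latestHCTaskRunTime + HC_BATCH_TASK_CACHE_TIMEOUT_IN_MILLIS < nowMillis;
    }

    // True when a task's last update is older than max(2 * pieceIntervalSeconds, 10) seconds.
    static boolean lastUpdateExpired(Instant lastUpdateTime, int pieceIntervalSeconds, Instant now) {
        int waitingTime = Math.max(2 * pieceIntervalSeconds, 10);
        return lastUpdateTime.plus(waitingTime, ChronoUnit.SECONDS).isBefore(now);
    }

    public static void main(String[] args) {
        Instant now = Instant.now();
        // With a 5s piece interval, 2 * 5 = 10, so the 10-second floor applies.
        System.out.println(lastUpdateExpired(now.minusSeconds(11), 5, now)); // true
        System.out.println(lastUpdateExpired(now.minusSeconds(9), 5, now));  // false
        System.out.println(hcBatchTaskExpired(now.toEpochMilli() - 700_000, now.toEpochMilli())); // true
    }
}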
int waitingTime = Math.max(2 * pieceIntervalSeconds, 10); return adTask.getLastUpdateTime().plus(waitingTime, ChronoUnit.SECONDS).isBefore(Instant.now()); } - private void resetTaskStateAsStopped( - ADTask adTask, - ExecutorFunction function, - TransportService transportService, - ActionListener listener - ) { - cleanDetectorCache(adTask, transportService, () -> { - String taskId = adTask.getTaskId(); - Map updatedFields = ImmutableMap.of(STATE_FIELD, ADTaskState.STOPPED.name()); - updateADTask(taskId, updatedFields, ActionListener.wrap(r -> { - adTask.setState(ADTaskState.STOPPED.name()); - if (function != null) { - function.execute(); - } - // For realtime anomaly detection, we only create detector level task, no entity level realtime task. - if (ADTaskType.HISTORICAL_HC_DETECTOR.name().equals(adTask.getTaskType())) { - // Reset running entity tasks as STOPPED - resetEntityTasksAsStopped(taskId); - } - }, e -> { - logger.error("Failed to update task state as STOPPED for task " + taskId, e); - listener.onFailure(e); - })); - }, listener); - } - - private void resetEntityTasksAsStopped(String detectorTaskId) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); - query.filter(new TermQueryBuilder(TASK_TYPE_FIELD, ADTaskType.HISTORICAL_HC_ENTITY.name())); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", STATE_FIELD, ADTaskState.STOPPED.name()); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (isNullOrEmpty(bulkFailures)) { - logger.debug("Updated {} child entity tasks state for detector task {}", r.getUpdated(), detectorTaskId); - } else { - logger.error("Failed to update child entity task's state for detector task {} ", detectorTaskId); - } - }, e -> logger.error("Exception happened when update child entity task's state for detector task " + detectorTaskId, e))); + @Override + protected boolean isHistoricalHCTask(TimeSeriesTask task) { + return ADTaskType.AD_HISTORICAL_HC_DETECTOR.name().equals(task.getTaskType()); } /** @@ -1322,18 +786,19 @@ private void resetEntityTasksAsStopped(String detectorTaskId) { * @param listener action listener * @param response type of listener */ - public void cleanDetectorCache( - ADTask adTask, + @Override + public void cleanConfigCache( + TimeSeriesTask adTask, TransportService transportService, ExecutorFunction function, ActionListener listener ) { String coordinatingNode = adTask.getCoordinatingNode(); - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); try { forwardADTaskToCoordinatingNode( - adTask, + (ADTask) adTask, ADTaskAction.CLEAN_CACHE, transportService, ActionListener.wrap(r -> { function.execute(); }, e -> { @@ -1357,9 +822,9 @@ public void cleanDetectorCache( } protected void cleanDetectorCache(ADTask adTask, TransportService transportService, ExecutorFunction function) { - String detectorId = adTask.getId(); + String detectorId = adTask.getConfigId(); String taskId = adTask.getTaskId(); - cleanDetectorCache( + cleanConfigCache( adTask, transportService, 
function, @@ -1411,9 +876,9 @@ public void getLatestHistoricalTaskProfile( * @param listener action listener */ private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener) { - String detectorId = adDetectorLevelTask.getId(); + String detectorId = adDetectorLevelTask.getConfigId(); - hashRing.getAllEligibleDataNodesWithKnownAdVersion(dataNodes -> { + hashRing.getAllEligibleDataNodesWithKnownVersion(dataNodes -> { ADTaskProfileRequest adTaskProfileRequest = new ADTaskProfileRequest(detectorId, dataNodes); client.execute(ADTaskProfileAction.INSTANCE, adTaskProfileRequest, ActionListener.wrap(response -> { if (response.hasFailures()) { @@ -1460,73 +925,26 @@ private void getADTaskProfile(ADTask adDetectorLevelTask, ActionListener listener - ) { - UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); - updateByQueryRequest.indices(DETECTION_STATE_INDEX); - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detector.getId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - // make sure we reset all latest task as false when user switch from single entity to HC, vice versa. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(getADTaskTypes(detectionDateRange, true)))); - updateByQueryRequest.setQuery(query); - updateByQueryRequest.setRefresh(true); - String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", IS_LATEST_FIELD, false); - updateByQueryRequest.setScript(new Script(script)); - - client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { - List bulkFailures = r.getBulkFailures(); - if (bulkFailures.isEmpty()) { - // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job - // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct - // coordinating node once realtime job starts. - // For historical analysis, this method will be called on coordinating node, so we can set coordinating - // node as local node. - String coordinatingNode = detectionDateRange == null ? null : clusterService.localNode().getId(); - createNewADTask(detector, detectionDateRange, user, coordinatingNode, listener); - } else { - logger.error("Failed to update old task's state for detector: {}, response: {} ", detector.getId(), r.toString()); - listener.onFailure(bulkFailures.get(0).getCause()); - } - }, e -> { - logger.error("Failed to reset old tasks as not latest for detector " + detector.getId(), e); - listener.onFailure(e); - })); - } - - private void createNewADTask( - AnomalyDetector detector, + @Override + protected void createNewTask( + Config config, DateRange detectionDateRange, User user, String coordinatingNode, - ActionListener listener + ActionListener listener ) { String userName = user == null ? 
null : user.getName(); Instant now = Instant.now(); - String taskType = getADTaskType(detector, detectionDateRange).name(); + String taskType = getTaskType(config, detectionDateRange).name(); ADTask adTask = new ADTask.Builder() - .detectorId(detector.getId()) - .detector(detector) + .detectorId(config.getId()) + .detector((AnomalyDetector) config) .isLatest(true) .taskType(taskType) .executionStartTime(now) .taskProgress(0.0f) .initProgress(0.0f) - .state(ADTaskState.CREATED.name()) + .state(TaskState.CREATED.name()) .lastUpdateTime(now) .startedBy(userName) .coordinatingNode(coordinatingNode) @@ -1534,57 +952,33 @@ private void createNewADTask( .user(user) .build(); - createADTaskDirectly( + createTaskDirectly( adTask, - r -> onIndexADTaskResponse( + r -> onIndexConfigTaskResponse( r, adTask, - (response, delegatedListener) -> cleanOldAdTaskDocs(response, adTask, delegatedListener), + (response, delegatedListener) -> cleanOldConfigTaskDocs(response, adTask, delegatedListener), listener ), listener ); } - /** - * Create AD task directly without checking index exists of not. - * [Important!] Make sure listener returns in function - * - * @param adTask AD task - * @param function consumer function - * @param listener action listener - * @param action listener response type - */ - public void createADTaskDirectly(ADTask adTask, Consumer function, ActionListener listener) { - IndexRequest request = new IndexRequest(DETECTION_STATE_INDEX); - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - request - .source(adTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { - logger.error("Failed to create AD task for detector " + adTask.getId(), e); - listener.onFailure(e); - })); - } catch (Exception e) { - logger.error("Failed to create AD task for detector " + adTask.getId(), e); - listener.onFailure(e); - } - } - - private void onIndexADTaskResponse( + @Override + protected void onIndexConfigTaskResponse( IndexResponse response, ADTask adTask, - BiConsumer> function, - ActionListener listener + BiConsumer> function, + ActionListener listener ) { if (response == null || response.getResult() != CREATED) { - String errorMsg = getShardsFailure(response); + String errorMsg = ExceptionUtil.getShardsFailure(response); listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); return; } adTask.setTaskId(response.getId()); - ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { - handleADTaskException(adTask, e); + ActionListener delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleTaskException(adTask, e); if (e instanceof DuplicateTaskException) { listener.onFailure(new OpenSearchStatusException(DETECTOR_IS_RUNNING, RestStatus.BAD_REQUEST)); } else { @@ -1593,17 +987,17 @@ private void onIndexADTaskResponse( // ADTaskManager#initRealtimeTaskCacheAndCleanupStaleCache for details. Here the // realtime task cache not inited yet when create AD task, so no need to cleanup. if (adTask.isHistoricalTask()) { - adTaskCacheManager.removeHistoricalTaskCache(adTask.getId()); + taskCacheManager.removeHistoricalTaskCache(adTask.getConfigId()); } listener.onFailure(e); } }); try { - // Put detector id in cache. If detector id already in cache, will throw + // Put config id in cache. If config id already in cache, will throw // DuplicateTaskException. 
This is to solve race condition when user send - // multiple start request for one historical detector. + // multiple start request for one historical run. if (adTask.isHistoricalTask()) { - adTaskCacheManager.add(adTask.getId(), adTask); + taskCacheManager.add(adTask.getConfigId(), adTask); } } catch (Exception e) { delegatedListener.onFailure(e); @@ -1614,253 +1008,23 @@ private void onIndexADTaskResponse( } } - private void cleanOldAdTaskDocs(IndexResponse response, ADTask adTask, ActionListener delegatedListener) { - BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, adTask.getId())); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, false)); - - if (adTask.isHistoricalTask()) { - // If historical task, only delete detector level task. It may take longer time to delete entity tasks. - // We will delete child task (entity task) of detector level task in hourly cron job. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - } else { - // We don't have entity level task for realtime detection, so will delete all tasks. - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(REALTIME_TASK_TYPES))); - } - - SearchRequest searchRequest = new SearchRequest(); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder - .query(query) - .sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC) - // Search query "from" starts from 0. - .from(maxOldAdTaskDocsPerDetector) - .size(MAX_OLD_AD_TASK_DOCS); - searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); - String detectorId = adTask.getId(); - - deleteTaskDocs(detectorId, searchRequest, () -> { - if (adTask.isHistoricalTask()) { - // run batch result action for historical detection - runBatchResultAction(response, adTask, delegatedListener); - } else { - // return response directly for realtime detection - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( - response.getId(), - response.getVersion(), - response.getSeqNo(), - response.getPrimaryTerm(), - RestStatus.OK - ); - delegatedListener.onResponse(anomalyDetectorJobResponse); - } - }, delegatedListener); - } - - protected void deleteTaskDocs( - String detectorId, - SearchRequest searchRequest, - ExecutorFunction function, - ActionListener listener - ) { - ActionListener searchListener = ActionListener.wrap(r -> { - Iterator iterator = r.getHits().iterator(); - if (iterator.hasNext()) { - BulkRequest bulkRequest = new BulkRequest(); - while (iterator.hasNext()) { - SearchHit searchHit = iterator.next(); - try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - ADTask adTask = ADTask.parse(parser, searchHit.getId()); - logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getId()); - bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); - } catch (Exception e) { - listener.onFailure(e); - } - } - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - logger.info("Old AD tasks deleted for detector {}", detectorId); - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.debug("Add detector task into cache. 
Task id: {}", bulkItemResponse.getId()); - // add deleted task in cache and delete its child tasks and AD results - adTaskCacheManager.addDeletedDetectorTask(bulkItemResponse.getId()); - } - } - } - // delete child tasks and AD results of this task - cleanChildTasksAndADResultsOfDeletedTask(); - - function.execute(); - }, e -> { - logger.warn("Failed to clean AD tasks for detector " + detectorId, e); - listener.onFailure(e); - })); - } else { - function.execute(); - } - }, e -> { - if (e instanceof IndexNotFoundException) { - function.execute(); - } else { - listener.onFailure(e); - } - }); - - client.search(searchRequest, searchListener); - } - - /** - * Poll deleted detector task from cache and delete its child tasks and AD results. - */ - public void cleanChildTasksAndADResultsOfDeletedTask() { - if (!adTaskCacheManager.hasDeletedDetectorTask()) { - return; - } - threadPool.schedule(() -> { - String taskId = adTaskCacheManager.pollDeletedDetectorTask(); - if (taskId == null) { - return; - } - DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(ALL_AD_RESULTS_INDEX_PATTERN); - deleteADResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); - client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(res -> { - logger.debug("Successfully deleted AD results of task " + taskId); - DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(DETECTION_STATE_INDEX); - deleteChildTasksRequest.setQuery(new TermsQueryBuilder(PARENT_TASK_ID_FIELD, taskId)); - - client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { - logger.debug("Successfully deleted child tasks of task " + taskId); - cleanChildTasksAndADResultsOfDeletedTask(); - }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); - }, ex -> { logger.error("Failed to delete AD results for task " + taskId, ex); })); - }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), AD_BATCH_TASK_THREAD_POOL_NAME); - } - - private void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { + @Override + protected void runBatchResultAction(IndexResponse response, ADTask adTask, ActionListener listener) { client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { String remoteOrLocal = r.isRunTaskRemotely() ? "remote" : "local"; logger .info( "AD task {} of detector {} dispatched to {} node {}", adTask.getTaskId(), - adTask.getId(), + adTask.getConfigId(), remoteOrLocal, r.getNodeId() ); - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse( - response.getId(), - response.getVersion(), - response.getSeqNo(), - response.getPrimaryTerm(), - RestStatus.OK - ); + JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); listener.onResponse(anomalyDetectorJobResponse); }, e -> listener.onFailure(e))); } - /** - * Handle exceptions for AD task. Update task state and record error message. - * - * @param adTask AD task - * @param e exception - */ - public void handleADTaskException(ADTask adTask, Exception e) { - // TODO: handle timeout exception - String state = ADTaskState.FAILED.name(); - Map updatedFields = new HashMap<>(); - if (e instanceof DuplicateTaskException) { - // If user send multiple start detector request, we will meet race condition. - // Cache manager will put first request in cache and throw DuplicateTaskException - // for the second request. 
We will delete the second task. - logger - .warn( - "There is already one running task for detector, detectorId:" - + adTask.getId() - + ". Will delete task " - + adTask.getTaskId() - ); - deleteADTask(adTask.getTaskId()); - return; - } - if (e instanceof TaskCancelledException) { - logger.info("AD task cancelled, taskId: {}, detectorId: {}", adTask.getTaskId(), adTask.getId()); - state = ADTaskState.STOPPED.name(); - String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); - if (stoppedBy != null) { - updatedFields.put(STOPPED_BY_FIELD, stoppedBy); - } - } else { - logger.error("Failed to execute AD batch task, task id: " + adTask.getTaskId() + ", detector id: " + adTask.getId(), e); - } - updatedFields.put(ERROR_FIELD, getErrorMessage(e)); - updatedFields.put(STATE_FIELD, state); - updatedFields.put(EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); - updateADTask(adTask.getTaskId(), updatedFields); - } - - /** - * Update AD task with specific fields. - * - * @param taskId AD task id - * @param updatedFields updated fields, key: filed name, value: new value - */ - public void updateADTask(String taskId, Map updatedFields) { - updateADTask(taskId, updatedFields, ActionListener.wrap(response -> { - if (response.status() == RestStatus.OK) { - logger.debug("Updated AD task successfully: {}, task id: {}", response.status(), taskId); - } else { - logger.error("Failed to update AD task {}, status: {}", taskId, response.status()); - } - }, e -> { logger.error("Failed to update task: " + taskId, e); })); - } - - /** - * Update AD task for specific fields. - * - * @param taskId task id - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateADTask(String taskId, Map updatedFields, ActionListener listener) { - UpdateRequest updateRequest = new UpdateRequest(DETECTION_STATE_INDEX, taskId); - Map updatedContent = new HashMap<>(); - updatedContent.putAll(updatedFields); - updatedContent.put(LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); - updateRequest.doc(updatedContent); - updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.update(updateRequest, listener); - } - - /** - * Delete AD task with task id. - * - * @param taskId AD task id - */ - public void deleteADTask(String taskId) { - deleteADTask( - taskId, - ActionListener - .wrap( - r -> { logger.info("Deleted AD task {} with status: {}", taskId, r.status()); }, - e -> { logger.error("Failed to delete AD task " + taskId, e); } - ) - ); - } - - /** - * Delete AD task with task id. - * - * @param taskId AD task id - * @param listener action listener - */ - public void deleteADTask(String taskId, ActionListener listener) { - DeleteRequest deleteRequest = new DeleteRequest(DETECTION_STATE_INDEX, taskId); - client.delete(deleteRequest, listener); - } - /** * Cancel running task by detector id. 
* @@ -1871,7 +1035,7 @@ public void deleteADTask(String taskId, ActionListener listener) * @return AD task cancellation state */ public ADTaskCancellationState cancelLocalTaskByDetectorId(String detectorId, String detectorTaskId, String reason, String userName) { - ADTaskCancellationState cancellationState = adTaskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); + ADTaskCancellationState cancellationState = taskCacheManager.cancelByDetectorId(detectorId, detectorTaskId, reason, userName); logger .debug( "Cancelled AD task for detector: {}, state: {}, cancelled by: {}, reason: {}", @@ -1932,7 +1096,7 @@ private void deleteADResultOfDetector(String detectorId) { ActionListener .wrap(response -> { logger.debug("Successfully deleted AD results of detector " + detectorId); }, exception -> { logger.error("Failed to delete AD results of detector " + detectorId, exception); - adTaskCacheManager.addDeletedDetector(detectorId); + taskCacheManager.addDeletedDetector(detectorId); }) ); } @@ -1941,145 +1105,12 @@ private void deleteADResultOfDetector(String detectorId) { * Clean AD results of deleted detector. */ public void cleanADResultOfDeletedDetector() { - String detectorId = adTaskCacheManager.pollDeletedDetector(); + String detectorId = taskCacheManager.pollDeletedDetector(); if (detectorId != null) { deleteADResultOfDetector(detectorId); } } - /** - * Update latest AD task of detector. - * - * @param detectorId detector id - * @param taskTypes task types - * @param updatedFields updated fields, key: filed name, value: new value - * @param listener action listener - */ - public void updateLatestADTask( - String detectorId, - List taskTypes, - Map updatedFields, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, taskTypes, (adTask) -> { - if (adTask.isPresent()) { - updateADTask(adTask.get().getTaskId(), updatedFields, listener); - } else { - listener.onFailure(new ResourceNotFoundException(detectorId, CAN_NOT_FIND_LATEST_TASK)); - } - }, null, false, listener); - } - - /** - * Update latest realtime task. 
- * - * @param detectorId detector id - * @param state task state - * @param error error - * @param transportService transport service - * @param listener action listener - */ - public void stopLatestRealtimeTask( - String detectorId, - ADTaskState state, - Exception error, - TransportService transportService, - ActionListener listener - ) { - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTask) -> { - if (adTask.isPresent() && !adTask.get().isDone()) { - Map updatedFields = new HashMap<>(); - updatedFields.put(ADTask.STATE_FIELD, state.name()); - if (error != null) { - updatedFields.put(ADTask.ERROR_FIELD, error.getMessage()); - } - ExecutorFunction function = () -> updateADTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { - if (error == null) { - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); - } else { - listener.onFailure(error); - } - }, e -> { listener.onFailure(e); })); - - String coordinatingNode = adTask.get().getCoordinatingNode(); - if (coordinatingNode != null && transportService != null) { - cleanDetectorCache(adTask.get(), transportService, function, listener); - } else { - function.execute(); - } - } else { - listener.onFailure(new OpenSearchStatusException("Anomaly detector job is already stopped: " + detectorId, RestStatus.OK)); - } - }, null, false, listener); - } - - /** - * Update realtime task cache on realtime detector's coordinating node. - * - * @param detectorId detector id - * @param state new state - * @param rcfTotalUpdates rcf total updates - * @param detectorIntervalInMinutes detector interval in minutes - * @param error error - * @param listener action listener - */ - public void updateLatestRealtimeTaskOnCoordinatingNode( - String detectorId, - String state, - Long rcfTotalUpdates, - Long detectorIntervalInMinutes, - String error, - ActionListener listener - ) { - Float initProgress = null; - String newState = null; - // calculate init progress and task state with RCF total updates - if (detectorIntervalInMinutes != null && rcfTotalUpdates != null) { - newState = ADTaskState.INIT.name(); - if (rcfTotalUpdates < NUM_MIN_SAMPLES) { - initProgress = (float) rcfTotalUpdates / NUM_MIN_SAMPLES; - } else { - newState = ADTaskState.RUNNING.name(); - initProgress = 1.0f; - } - } - // Check if new state is not null and override state calculated with rcf total updates - if (state != null) { - newState = state; - } - - error = Optional.ofNullable(error).orElse(""); - if (!adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId, newState, initProgress, error)) { - // If task not changed, no need to update, just return - listener.onResponse(null); - return; - } - Map updatedFields = new HashMap<>(); - updatedFields.put(COORDINATING_NODE_FIELD, clusterService.localNode().getId()); - if (initProgress != null) { - updatedFields.put(INIT_PROGRESS_FIELD, initProgress); - updatedFields.put(ESTIMATED_MINUTES_LEFT_FIELD, Math.max(0, NUM_MIN_SAMPLES - rcfTotalUpdates) * detectorIntervalInMinutes); - } - if (newState != null) { - updatedFields.put(STATE_FIELD, newState); - } - if (error != null) { - updatedFields.put(ERROR_FIELD, error); - } - Float finalInitProgress = initProgress; - // Variable used in lambda expression should be final or effectively final - String finalError = error; - String finalNewState = newState; - updateLatestADTask(detectorId, ADTaskType.REALTIME_TASK_TYPES, updatedFields, ActionListener.wrap(r -> { - logger.debug("Updated latest realtime AD task successfully for 
detector {}", detectorId); - adTaskCacheManager.updateRealtimeTaskCache(detectorId, finalNewState, finalInitProgress, finalError); - listener.onResponse(r); - }, e -> { - logger.error("Failed to update realtime task for detector " + detectorId, e); - listener.onFailure(e); - })); - } - /** * Init realtime task cache and clean up realtime task cache on old coordinating node. Realtime AD * depends on job scheduler to choose node (job coordinating node) to run AD job. Nodes have primary @@ -2097,29 +1128,31 @@ public void updateLatestRealtimeTaskOnCoordinatingNode( * @param transportService transport service * @param listener listener */ - public void initRealtimeTaskCacheAndCleanupStaleCache( + @Override + public void initCacheWithCleanupIfRequired( String detectorId, - AnomalyDetector detector, + Config config, TransportService transportService, ActionListener listener ) { try { - if (adTaskCacheManager.getRealtimeTaskCache(detectorId) != null) { + if (taskCacheManager.getRealtimeTaskCache(detectorId) != null) { listener.onResponse(false); return; } - getAndExecuteOnLatestDetectorLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { + AnomalyDetector detector = (AnomalyDetector) config; + getAndExecuteOnLatestConfigLevelTask(detectorId, REALTIME_TASK_TYPES, (adTaskOptional) -> { if (!adTaskOptional.isPresent()) { logger.debug("Can't find realtime task for detector {}, init realtime task cache directly", detectorId); - ExecutorFunction function = () -> createNewADTask( + ExecutorFunction function = () -> createNewTask( detector, null, detector.getUser(), clusterService.localNode().getId(), ActionListener.wrap(r -> { logger.info("Recreate realtime task successfully for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, e -> { logger.error("Failed to recreate realtime task for detector " + detectorId, e); @@ -2141,19 +1174,19 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( localNodeId, detectorId ); - cleanDetectorCache(adTask, transportService, () -> { + cleanConfigCache(adTask, transportService, () -> { logger .info( "Realtime task cache cleaned on old coordinating node {} for detector {}", oldCoordinatingNode, detectorId ); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); }, listener); } else { logger.info("Init realtime task cache for detector {}", detectorId); - adTaskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); + taskCacheManager.initRealtimeTaskCache(detectorId, detector.getIntervalInMilliseconds()); listener.onResponse(true); } }, transportService, false, listener); @@ -2164,16 +1197,16 @@ public void initRealtimeTaskCacheAndCleanupStaleCache( } private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener) { - if (detectionIndices.doesStateIndexExist()) { + if (indexManagement.doesStateIndexExist()) { function.execute(); } else { // If detection index doesn't exist, create index and execute function. 
- detectionIndices.initStateIndex(ActionListener.wrap(r -> { + indexManagement.initStateIndex(ActionListener.wrap(r -> { if (r.isAcknowledged()) { logger.info("Created {} with mappings.", DETECTION_STATE_INDEX); function.execute(); } else { - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); logger.warn(error); listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); } @@ -2188,14 +1221,6 @@ private void recreateRealtimeTask(ExecutorFunction function, ActionListener listener + ActionListener listener ) { try { ADTaskAction action = getAdEntityTaskAction(adTask, exception); @@ -2253,7 +1278,7 @@ private void entityTaskDone( private ADTaskAction getAdEntityTaskAction(ADTask adTask, Exception exception) { ADTaskAction action = ADTaskAction.NEXT_ENTITY; if (exception != null) { - adTask.setError(getErrorMessage(exception)); + adTask.setError(ExceptionUtil.getErrorMessage(exception)); if (exception instanceof LimitExceededException && isRetryableError(exception.getMessage())) { action = ADTaskAction.PUSH_BACK_ENTITY; } else if (exception instanceof TaskCancelledException || exception instanceof EndRunException) { @@ -2289,14 +1314,14 @@ public boolean isRetryableError(String error) { * @param state AD task state * @param listener action listener */ - public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListener listener) { - String detectorId = adTask.getId(); + public void setHCDetectorTaskDone(ADTask adTask, TaskState state, ActionListener listener) { + String detectorId = adTask.getConfigId(); String taskId = adTask.isEntityTask() ? adTask.getParentTaskId() : adTask.getTaskId(); String detectorTaskId = adTask.getDetectorLevelTaskId(); ActionListener wrappedListener = ActionListener.wrap(response -> { logger.info("Historical HC detector done with state: {}. Remove from cache, detector id:{}", state.name(), detectorId); - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }, e -> { // HC detector task may fail to update as FINISHED for some edge case if failed to get updating semaphore. // Will reset task state when get detector with task or maintain tasks in hourly cron. @@ -2305,15 +1330,15 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen } else { logger.error("Failed to update task: " + taskId, e); } - adTaskCacheManager.removeHistoricalTaskCache(detectorId); + taskCacheManager.removeHistoricalTaskCache(detectorId); }); long timeoutInMillis = 2000;// wait for 2 seconds to acquire updating HC detector task semaphore - if (state == ADTaskState.FINISHED) { - this.countEntityTasksByState(detectorTaskId, ImmutableList.of(ADTaskState.FINISHED), ActionListener.wrap(r -> { - logger.info("number of finished entity tasks: {}, for detector {}", r, adTask.getId()); + if (state == TaskState.FINISHED) { + this.countEntityTasksByState(detectorTaskId, ImmutableList.of(TaskState.FINISHED), ActionListener.wrap(r -> { + logger.info("number of finished entity tasks: {}, for detector {}", r, adTask.getConfigId()); // Set task as FAILED if no finished entity task; otherwise set as FINISHED - ADTaskState hcDetectorTaskState = r == 0 ? ADTaskState.FAILED : ADTaskState.FINISHED; + TaskState hcDetectorTaskState = r == 0 ? 
TaskState.FAILED : TaskState.FINISHED; // execute in AD batch task thread pool in case waiting for semaphore waste any shared OpenSearch thread pool threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { updateADHCDetectorTask( @@ -2321,11 +1346,11 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, hcDetectorTaskState.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2335,20 +1360,20 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen }, e -> { logger.error("Failed to get finished entity tasks", e); - String errorMessage = getErrorMessage(e); + String errorMessage = ExceptionUtil.getErrorMessage(e); threadPool.executor(AD_BATCH_TASK_THREAD_POOL_NAME).execute(() -> { updateADHCDetectorTask( detectorId, taskId, ImmutableMap .of( - STATE_FIELD, - ADTaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. - TASK_PROGRESS_FIELD, + TimeSeriesTask.STATE_FIELD, + TaskState.FAILED.name(),// set as FAILED if fail to get finished entity tasks. + TimeSeriesTask.TASK_PROGRESS_FIELD, 1.0, - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, errorMessage, - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2363,11 +1388,11 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen taskId, ImmutableMap .of( - STATE_FIELD, + TimeSeriesTask.STATE_FIELD, state.name(), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError(), - EXECUTION_END_TIME_FIELD, + TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli() ), timeoutInMillis, @@ -2377,7 +1402,7 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen } - listener.onResponse(new AnomalyDetectorJobResponse(taskId, 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(taskId)); } /** @@ -2387,11 +1412,14 @@ public void setHCDetectorTaskDone(ADTask adTask, ADTaskState state, ActionListen * @param taskStates task states * @param listener action listener */ - public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { + public void countEntityTasksByState(String detectorTaskId, List taskStates, ActionListener listener) { BoolQueryBuilder queryBuilder = new BoolQueryBuilder(); - queryBuilder.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, detectorTaskId)); + queryBuilder.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, detectorTaskId)); if (taskStates != null && taskStates.size() > 0) { - queryBuilder.filter(new TermsQueryBuilder(STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList()))); + queryBuilder + .filter( + new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, taskStates.stream().map(s -> s.name()).collect(Collectors.toList())) + ); } SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); sourceBuilder.query(queryBuilder); @@ -2454,19 +1482,19 @@ private void updateADHCDetectorTask( ActionListener listener ) { try { - if (adTaskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { + if (taskCacheManager.tryAcquireTaskUpdatingSemaphore(detectorId, timeoutInMillis)) { try { - updateADTask( + updateTask( taskId, updatedFields, - ActionListener.runAfter(listener, () -> { 
adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) + ActionListener.runAfter(listener, () -> { taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); }) ); } catch (Exception e) { logger.error("Failed to update detector task " + taskId, e); - adTaskCacheManager.releaseTaskUpdatingSemaphore(detectorId); + taskCacheManager.releaseTaskUpdatingSemaphore(detectorId); listener.onFailure(e); } - } else if (!adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + } else if (!taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { // It's possible that AD task cache cleaned up by other task. Return null to avoid too many failure logs. logger.info("HC detector task cache does not exist, detectorId:{}, taskId:{}", detectorId, taskId); listener.onResponse(null); @@ -2491,12 +1519,8 @@ private void updateADHCDetectorTask( * @param transportService transport service * @param listener action listener */ - public void runNextEntityForHCADHistorical( - ADTask adTask, - TransportService transportService, - ActionListener listener - ) { - String detectorId = adTask.getId(); + public void runNextEntityForHCADHistorical(ADTask adTask, TransportService transportService, ActionListener listener) { + String detectorId = adTask.getConfigId(); int scaleDelta = scaleTaskSlots( adTask, transportService, @@ -2512,9 +1536,9 @@ public void runNextEntityForHCADHistorical( "Have scaled down task slots. Will not poll next entity for detector {}, task {}, task slots: {}", detectorId, adTask.getTaskId(), - adTaskCacheManager.getDetectorTaskSlots(detectorId) + taskCacheManager.getDetectorTaskSlots(detectorId) ); - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.ACCEPTED)); + listener.onResponse(new JobResponse(detectorId)); return; } client.execute(ADBatchAnomalyResultAction.INSTANCE, new ADBatchAnomalyResultRequest(adTask), ActionListener.wrap(r -> { @@ -2527,7 +1551,7 @@ public void runNextEntityForHCADHistorical( remoteOrLocal, r.getNodeId() ); - AnomalyDetectorJobResponse anomalyDetectorJobResponse = new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK); + JobResponse anomalyDetectorJobResponse = new JobResponse(detectorId); listener.onResponse(anomalyDetectorJobResponse); }, e -> { listener.onFailure(e); })); } @@ -2542,12 +1566,8 @@ public void runNextEntityForHCADHistorical( * @param scaleUpListener action listener * @return task slots scale delta */ - protected int scaleTaskSlots( - ADTask adTask, - TransportService transportService, - ActionListener scaleUpListener - ) { - String detectorId = adTask.getId(); + protected int scaleTaskSlots(ADTask adTask, TransportService transportService, ActionListener scaleUpListener) { + String detectorId = adTask.getConfigId(); if (!scaleEntityTaskLane.tryAcquire()) { logger.debug("Can't get scaleEntityTaskLane semaphore"); return 0; @@ -2555,9 +1575,9 @@ protected int scaleTaskSlots( try { int scaleDelta = detectorTaskSlotScaleDelta(detectorId); logger.debug("start to scale task slots for detector {} with delta {}", detectorId, scaleDelta); - if (adTaskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { + if (taskCacheManager.getAvailableNewEntityTaskLanes(detectorId) <= 0 && scaleDelta > 0) { // scale up to run more entities in parallel - Instant lastScaleEntityTaskLaneTime = adTaskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); + Instant lastScaleEntityTaskLaneTime = taskCacheManager.getLastScaleEntityTaskLaneTime(detectorId); if (lastScaleEntityTaskLaneTime == null) { 
logger.debug("lastScaleEntityTaskLaneTime is null for detector {}", detectorId); scaleEntityTaskLane.release(); @@ -2567,7 +1587,7 @@ protected int scaleTaskSlots( .plusMillis(SCALE_ENTITY_TASK_LANE_INTERVAL_IN_MILLIS) .isBefore(Instant.now()); if (lastScaleTimeExpired) { - adTaskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); + taskCacheManager.refreshLastScaleEntityTaskLaneTime(detectorId); logger.debug("Forward scale entity task lane request to lead node for detector {}", detectorId); forwardScaleTaskSlotRequestToLeadNode( adTask, @@ -2585,9 +1605,9 @@ protected int scaleTaskSlots( } } else { if (scaleDelta < 0) { // scale down to release task slots for other detectors - int runningEntityCount = adTaskCacheManager.getRunningEntityCount(detectorId) + adTaskCacheManager + int runningEntityCount = taskCacheManager.getRunningEntityCount(detectorId) + taskCacheManager .getTempEntityCount(detectorId); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDownDelta = Math.min(assignedTaskSlots - runningEntityCount, 0 - scaleDelta); logger .debug( @@ -2597,7 +1617,7 @@ protected int scaleTaskSlots( runningEntityCount, scaleDownDelta ); - adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); + taskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, scaleDownDelta); } scaleEntityTaskLane.release(); } @@ -2626,13 +1646,13 @@ protected int scaleTaskSlots( * @return detector task slots scale delta */ public int detectorTaskSlotScaleDelta(String detectorId) { - DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalAdVersion(); - int unfinishedEntities = adTaskCacheManager.getUnfinishedEntityCount(detectorId); + DiscoveryNode[] eligibleDataNodes = hashRing.getNodesWithSameLocalVersion(); + int unfinishedEntities = taskCacheManager.getUnfinishedEntityCount(detectorId); int totalTaskSlots = eligibleDataNodes.length * maxAdBatchTaskPerNode; int taskLaneLimit = Math.min(unfinishedEntities, Math.min(totalTaskSlots, maxRunningEntitiesPerDetector)); - adTaskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); + taskCacheManager.setDetectorTaskLaneLimit(detectorId, taskLaneLimit); - int assignedTaskSlots = adTaskCacheManager.getDetectorTaskSlots(detectorId); + int assignedTaskSlots = taskCacheManager.getDetectorTaskSlots(detectorId); int scaleDelta = taskLaneLimit - assignedTaskSlots; logger .debug( @@ -2656,8 +1676,8 @@ public int detectorTaskSlotScaleDelta(String detectorId) { * @return task progress */ public float hcDetectorProgress(String detectorId) { - int entityCount = adTaskCacheManager.getTopEntityCount(detectorId); - int leftEntities = adTaskCacheManager.getPendingEntityCount(detectorId) + adTaskCacheManager.getRunningEntityCount(detectorId); + int entityCount = taskCacheManager.getTopEntityCount(detectorId); + int leftEntities = taskCacheManager.getPendingEntityCount(detectorId) + taskCacheManager.getRunningEntityCount(detectorId); return 1 - (float) leftEntities / entityCount; } @@ -2667,23 +1687,23 @@ public float hcDetectorProgress(String detectorId) { * @return list of AD task profile */ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { - List tasksOfDetector = adTaskCacheManager.getTasksOfDetector(detectorId); + List tasksOfDetector = taskCacheManager.getTasksOfDetector(detectorId); ADTaskProfile detectorTaskProfile = null; String localNodeId = clusterService.localNode().getId(); - if 
(adTaskCacheManager.isHCTaskRunning(detectorId)) { + if (taskCacheManager.isHCTaskRunning(detectorId)) { detectorTaskProfile = new ADTaskProfile(); - if (adTaskCacheManager.isHCTaskCoordinatingNode(detectorId)) { + if (taskCacheManager.isHCTaskCoordinatingNode(detectorId)) { detectorTaskProfile.setNodeId(localNodeId); - detectorTaskProfile.setTaskId(adTaskCacheManager.getDetectorTaskId(detectorId)); - detectorTaskProfile.setDetectorTaskSlots(adTaskCacheManager.getDetectorTaskSlots(detectorId)); - detectorTaskProfile.setTotalEntitiesInited(adTaskCacheManager.topEntityInited(detectorId)); - detectorTaskProfile.setTotalEntitiesCount(adTaskCacheManager.getTopEntityCount(detectorId)); - detectorTaskProfile.setPendingEntitiesCount(adTaskCacheManager.getPendingEntityCount(detectorId)); - detectorTaskProfile.setRunningEntitiesCount(adTaskCacheManager.getRunningEntityCount(detectorId)); - detectorTaskProfile.setRunningEntities(adTaskCacheManager.getRunningEntities(detectorId)); - detectorTaskProfile.setAdTaskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()); - Instant latestHCTaskRunTime = adTaskCacheManager.getLatestHCTaskRunTime(detectorId); + detectorTaskProfile.setTaskId(taskCacheManager.getDetectorTaskId(detectorId)); + detectorTaskProfile.setDetectorTaskSlots(taskCacheManager.getDetectorTaskSlots(detectorId)); + detectorTaskProfile.setTotalEntitiesInited(taskCacheManager.topEntityInited(detectorId)); + detectorTaskProfile.setTotalEntitiesCount(taskCacheManager.getTopEntityCount(detectorId)); + detectorTaskProfile.setPendingEntitiesCount(taskCacheManager.getPendingEntityCount(detectorId)); + detectorTaskProfile.setRunningEntitiesCount(taskCacheManager.getRunningEntityCount(detectorId)); + detectorTaskProfile.setRunningEntities(taskCacheManager.getRunningEntities(detectorId)); + detectorTaskProfile.setAdTaskType(ADTaskType.AD_HISTORICAL_HC_DETECTOR.name()); + Instant latestHCTaskRunTime = taskCacheManager.getLatestHCTaskRunTime(detectorId); if (latestHCTaskRunTime != null) { detectorTaskProfile.setLatestHCTaskRunTime(latestHCTaskRunTime.toEpochMilli()); } @@ -2693,15 +1713,15 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { tasksOfDetector.forEach(taskId -> { ADEntityTaskProfile entityTaskProfile = new ADEntityTaskProfile( - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - adTaskCacheManager.getModelSize(taskId), + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId, - adTaskCacheManager.getEntity(taskId), + taskCacheManager.getEntity(taskId), taskId, - ADTaskType.HISTORICAL_HC_ENTITY.name() + ADTaskType.AD_HISTORICAL_HC_ENTITY.name() ); entityTaskProfiles.add(entityTaskProfile); }); @@ -2718,12 +1738,12 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { if (tasksOfDetector.size() == 1) { String taskId = tasksOfDetector.get(0); detectorTaskProfile = new ADTaskProfile( - adTaskCacheManager.getDetectorTaskId(detectorId), - adTaskCacheManager.getShingle(taskId).size(), - adTaskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), - adTaskCacheManager.isThresholdModelTrained(taskId), - 
adTaskCacheManager.getThresholdModelTrainingDataSize(taskId), - adTaskCacheManager.getModelSize(taskId), + taskCacheManager.getDetectorTaskId(detectorId), + taskCacheManager.getShingle(taskId).size(), + taskCacheManager.getTRcfModel(taskId).getForest().getTotalUpdates(), + taskCacheManager.isThresholdModelTrained(taskId), + taskCacheManager.getThresholdModelTrainingDataSize(taskId), + taskCacheManager.getModelSize(taskId), localNodeId ); // Single-flow detector only has 1 task slot. @@ -2739,7 +1759,7 @@ public ADTaskProfile getLocalADTaskProfilesByDetectorId(String detectorId) { // Clean expired HC batch task run states as it may exists after HC historical analysis done if user cancel // before querying top entities done. We will clean it in hourly cron, check "maintainRunningHistoricalTasks" // method. Clean it up here when get task profile to release memory earlier. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); } ); logger.debug("Local AD task profile of detector {}: {}", detectorId, detectorTaskProfile); @@ -2790,35 +1810,20 @@ public synchronized void removeStaleRunningEntity( ADTask adTask, String entity, TransportService transportService, - ActionListener listener + ActionListener listener ) { - String detectorId = adTask.getId(); - boolean removed = adTaskCacheManager.removeRunningEntity(detectorId, entity); - if (removed && adTaskCacheManager.getPendingEntityCount(detectorId) > 0) { + String detectorId = adTask.getConfigId(); + boolean removed = taskCacheManager.removeRunningEntity(detectorId, entity); + if (removed && taskCacheManager.getPendingEntityCount(detectorId) > 0) { logger.debug("kick off next pending entities"); this.runNextEntityForHCADHistorical(adTask, transportService, listener); } else { - if (!adTaskCacheManager.hasEntity(detectorId)) { - setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + if (!taskCacheManager.hasEntity(detectorId)) { + setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } } } - public boolean skipUpdateHCRealtimeTask(String detectorId, String error) { - ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() == 1.0 - && Objects.equals(error, realtimeTaskCache.getError()); - } - - public boolean isHCRealtimeTaskStartInitializing(String detectorId) { - ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - return realtimeTaskCache != null - && realtimeTaskCache.getInitProgress() != null - && realtimeTaskCache.getInitProgress().floatValue() > 0; - } - public String convertEntityToString(ADTask adTask) { if (adTask == null || !adTask.isEntityTask()) { return null; @@ -2907,45 +1912,8 @@ public void getADTask(String taskId, ActionListener> listener) })); } - /** - * Set old AD task's latest flag as false. 
- * @param adTasks list of AD tasks - */ - public void resetLatestFlagAsFalse(List adTasks) { - if (adTasks == null || adTasks.size() == 0) { - return; - } - BulkRequest bulkRequest = new BulkRequest(); - adTasks.forEach(task -> { - try { - task.setLatest(false); - task.setLastUpdateTime(Instant.now()); - IndexRequest indexRequest = new IndexRequest(DETECTION_STATE_INDEX) - .id(task.getTaskId()) - .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); - bulkRequest.add(indexRequest); - } catch (Exception e) { - logger.error("Fail to parse task AD task to XContent, task id " + task.getTaskId(), e); - } - }); - - bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { - BulkItemResponse[] bulkItemResponses = res.getItems(); - if (bulkItemResponses != null && bulkItemResponses.length > 0) { - for (BulkItemResponse bulkItemResponse : bulkItemResponses) { - if (!bulkItemResponse.isFailed()) { - logger.warn("Reset AD tasks latest flag as false Successfully. Task id: {}", bulkItemResponse.getId()); - } else { - logger.warn("Failed to reset AD tasks latest flag as false. Task id: " + bulkItemResponse.getId()); - } - } - } - }, e -> { logger.warn("Failed to reset AD tasks latest flag as false", e); })); - } - public int getLocalAdUsedBatchTaskSlot() { - return adTaskCacheManager.getTotalBatchTaskCount(); + return taskCacheManager.getTotalBatchTaskCount(); } /** @@ -2971,7 +1939,7 @@ public int getLocalAdUsedBatchTaskSlot() { * @return assigned batch task slots */ public int getLocalAdAssignedBatchTaskSlot() { - return adTaskCacheManager.getTotalDetectorTaskSlots(); + return taskCacheManager.getTotalDetectorTaskSlots(); } // ========================================================= @@ -2991,23 +1959,23 @@ public int getLocalAdAssignedBatchTaskSlot() { */ public void maintainRunningHistoricalTasks(TransportService transportService, int size) { // Clean expired HC batch task run state cache. - adTaskCacheManager.cleanExpiredHCBatchTaskRunStates(); + taskCacheManager.cleanExpiredHCBatchTaskRunStates(); // Find owning node with highest AD version to make sure we only have 1 node maintain running historical tasks // and we use the latest logic. - Optional owningNode = hashRing.getOwningNodeWithHighestAdVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); + Optional owningNode = hashRing.getOwningNodeWithHighestVersion(AD_TASK_MAINTAINENCE_NODE_MODEL_ID); if (!owningNode.isPresent() || !clusterService.localNode().getId().equals(owningNode.get().getId())) { return; } logger.info("Start to maintain running historical tasks"); BoolQueryBuilder query = new BoolQueryBuilder(); - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); - query.filter(new TermsQueryBuilder(STATE_FIELD, NOT_ENDED_STATES)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // default maintain interval is 5 seconds, so maintain 10 tasks will take at least 50 seconds. 
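/*
 * A sketch of the paced maintenance loop implied by the comment above: drain a queue of
 * task ids one at a time, waiting a fixed interval before each item, so N tasks take at
 * least N * interval to maintain (10 tasks at 5-second spacing is at least 50 seconds).
 * The queue, interval, and maintain step are simplified stand-ins for the plugin's versions.
 */
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

final class PacedMaintainer {
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
    private final long intervalSeconds;
    private final Consumer<String> maintainOneTask;

    PacedMaintainer(long intervalSeconds, Consumer<String> maintainOneTask) {
        this.intervalSeconds = intervalSeconds;
        this.maintainOneTask = maintainOneTask;
    }

    void maintain(ConcurrentLinkedQueue<String> taskQueue) {
        String taskId = taskQueue.poll();
        if (taskId == null) {
            return; // queue drained, stop rescheduling
        }
        scheduler.schedule(() -> {
            maintainOneTask.accept(taskId);
            maintain(taskQueue); // self-reschedule: one task per interval
        }, intervalSeconds, TimeUnit.SECONDS);
    }
}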
- sourceBuilder.query(query).sort(LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, SortOrder.DESC).size(size); SearchRequest searchRequest = new SearchRequest(); searchRequest.source(sourceBuilder); searchRequest.indices(DETECTION_STATE_INDEX); @@ -3045,7 +2013,7 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue return; } threadPool.schedule(() -> { - resetHistoricalDetectorTaskState(ImmutableList.of(adTask), () -> { + resetHistoricalConfigTaskState(ImmutableList.of(adTask), () -> { logger.debug("Finished maintaining running historical task {}", adTask.getTaskId()); maintainRunningHistoricalTask(taskQueue, transportService); }, @@ -3053,7 +2021,12 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue ActionListener .wrap( r -> { - logger.debug("Reset historical task state done for task {}, detector {}", adTask.getTaskId(), adTask.getId()); + logger + .debug( + "Reset historical task state done for task {}, detector {}", + adTask.getTaskId(), + adTask.getConfigId() + ); }, e -> { logger.error("Failed to reset historical task state for task " + adTask.getTaskId(), e); } ) @@ -3062,20 +2035,36 @@ private void maintainRunningHistoricalTask(ConcurrentLinkedQueue taskQue } /** - * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime - * task cache directly if expired. + * Get list of task types. + * 1. If date range is null, will return all realtime task types + * 2. If date range is not null, will return all historical detector level tasks types + * if resetLatestTaskStateFlag is true; otherwise return all historical tasks types include + * HC entity level task type. + * @param dateRange detection date range + * @param resetLatestTaskStateFlag reset latest task state or not + * @return list of AD task types */ - public void maintainRunningRealtimeTasks() { - String[] detectorIds = adTaskCacheManager.getDetectorIdsInRealtimeTaskCache(); - if (detectorIds == null || detectorIds.length == 0) { - return; - } - for (int i = 0; i < detectorIds.length; i++) { - String detectorId = detectorIds[i]; - ADRealtimeTaskCache taskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); - if (taskCache != null && taskCache.expired()) { - adTaskCacheManager.removeRealtimeTaskCache(detectorId); + @Override + protected List getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag) { + if (dateRange == null) { + return REALTIME_TASK_TYPES; + } else { + if (resetLatestTaskStateFlag) { + // return all task types include HC entity task to make sure we can reset all tasks latest flag + return ALL_HISTORICAL_TASK_TYPES; + } else { + return HISTORICAL_DETECTOR_TASK_TYPES; } } } + + @Override + protected List getRealTimeTaskTypes() { + return REALTIME_TASK_TYPES; + } + + @Override + protected BiCheckedFunction getTaskParser() { + return ADTask::parse; + } } diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java index 84fe0c6fe..df6194353 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchAnomalyResultAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class 
ADBatchAnomalyResultAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK; public static final ADBatchAnomalyResultAction INSTANCE = new ADBatchAnomalyResultAction(); private ADBatchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java index d865ec14c..84a22b261 100644 --- a/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADBatchTaskRemoteExecutionAction.java @@ -14,10 +14,10 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK_REMOTE; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADBatchTaskRemoteExecutionAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK_REMOTE; public static final ADBatchTaskRemoteExecutionAction INSTANCE = new ADBatchTaskRemoteExecutionAction(); private ADBatchTaskRemoteExecutionAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java index 31f20fa00..d20759f70 100644 --- a/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADCancelTaskAction.java @@ -14,11 +14,11 @@ import static org.opensearch.ad.constant.ADCommonName.CANCEL_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ADCancelTaskAction extends ActionType { - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/" + CANCEL_TASK; public static final ADCancelTaskAction INSTANCE = new ADCancelTaskAction(); private ADCancelTaskAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java index 041d543b7..e54a4747e 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkAction.java @@ -12,18 +12,19 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; import org.opensearch.common.settings.Settings; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.transport.TransportRequestOptions; -public class ADResultBulkAction extends ActionType { +public class ADResultBulkAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
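/*
 * The naming convention at work in these action classes, sketched generically:
 * transport-only actions are registered under a reserved internal prefix so they stay
 * namespaced apart from public-facing REST action names. InternalActionNames and the
 * prefix value below are illustrative, not plugin code; only the prefix-plus-suffix
 * pattern mirrors the constants above.
 */
final class InternalActionNames {
    // illustrative stand-in for ADCommonValue.INTERNAL_ACTION_PREFIX
    static final String INTERNAL_ACTION_PREFIX = "cluster:admin/plugin-internal/";

    static String internal(String suffix) {
        return INTERNAL_ACTION_PREFIX + suffix;
    }

    public static void main(String[] args) {
        // e.g. the result bulk action name is built as prefix + "write/bulk"
        System.out.println(internal("write/bulk"));
    }
}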
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; public static final ADResultBulkAction INSTANCE = new ADResultBulkAction(); private ADResultBulkAction() { - super(NAME, ADResultBulkResponse::new); + super(NAME, ResultBulkResponse::new); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java index f5f361f69..0f8430a25 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkRequest.java @@ -12,73 +12,19 @@ package org.opensearch.ad.transport; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.action.ValidateActions; -import org.opensearch.ad.ratelimit.ResultWriteRequest; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.transport.ResultBulkRequest; -public class ADResultBulkRequest extends ActionRequest implements Writeable { - private final List anomalyResults; - static final String NO_REQUESTS_ADDED_ERR = "no requests added"; +public class ADResultBulkRequest extends ResultBulkRequest { public ADResultBulkRequest() { - anomalyResults = new ArrayList<>(); + super(); } public ADResultBulkRequest(StreamInput in) throws IOException { - super(in); - int size = in.readVInt(); - anomalyResults = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - anomalyResults.add(new ResultWriteRequest(in)); - } - } - - @Override - public ActionRequestValidationException validate() { - ActionRequestValidationException validationException = null; - if (anomalyResults.isEmpty()) { - validationException = ValidateActions.addValidationError(NO_REQUESTS_ADDED_ERR, validationException); - } - return validationException; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeVInt(anomalyResults.size()); - for (ResultWriteRequest result : anomalyResults) { - result.writeTo(out); - } - } - - /** - * - * @return all of the results to send - */ - public List getAnomalyResults() { - return anomalyResults; - } - - /** - * Add result to send - * @param resultWriteRequest The result write request - */ - public void add(ResultWriteRequest resultWriteRequest) { - anomalyResults.add(resultWriteRequest); - } - - /** - * - * @return total index requests - */ - public int numberOfActions() { - return anomalyResults.size(); + super(in, ADResultWriteRequest::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java index 9928dd1dc..6b1c94fb4 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ADResultBulkTransportAction.java @@ -14,45 +14,32 @@ import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT; import static 
org.opensearch.common.xcontent.XContentFactory.jsonBuilder; -import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; import java.io.IOException; import java.util.List; -import java.util.Random; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.ActionListener; -import org.opensearch.action.bulk.BulkAction; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.util.BulkUtil; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.IndexingPressure; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; -public class ADResultBulkTransportAction extends HandledTransportAction { +public class ADResultBulkTransportAction extends ResultBulkTransportAction { private static final Logger LOG = LogManager.getLogger(ADResultBulkTransportAction.class); - private IndexingPressure indexingPressure; - private final long primaryAndCoordinatingLimits; - private float softLimit; - private float hardLimit; - private String indexName; - private Client client; - private Random random; @Inject public ADResultBulkTransportAction( @@ -63,69 +50,51 @@ public ADResultBulkTransportAction( ClusterService clusterService, Client client ) { - super(ADResultBulkAction.NAME, transportService, actionFilters, ADResultBulkRequest::new, ThreadPool.Names.SAME); - this.indexingPressure = indexingPressure; - this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); - this.softLimit = AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings); - this.hardLimit = AD_INDEX_PRESSURE_HARD_LIMIT.get(settings); - this.indexName = ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; - this.client = client; + super( + ADResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + AD_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + AD_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, + ADResultBulkRequest::new + ); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); - // random seed is 42. Can be any number - this.random = new Random(42); } @Override - protected void doExecute(Task task, ADResultBulkRequest request, ActionListener listener) { - // Concurrent indexing memory limit = 10% of heap - // indexing pressure = indexing bytes / indexing limit - // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. 
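/*
 * A sketch of the three-tier backpressure policy this method describes, with a
 * hypothetical Result type standing in for AnomalyResult and the AD limit settings.
 * Below the soft limit everything is indexed; between the soft and hard limits,
 * low-priority results are sampled with probability (1 - pressure) while high-priority
 * ones always pass; above the hard limit only high-priority results survive.
 */
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

final class BackpressureFilter {
    static final class Result {
        final boolean highPriority; // e.g. non-zero anomaly grade or an error result
        Result(boolean highPriority) {
            this.highPriority = highPriority;
        }
    }

    private final float softLimit; // e.g. 0.6f
    private final float hardLimit; // e.g. 0.9f
    private final Random random = new Random(42);

    BackpressureFilter(float softLimit, float hardLimit) {
        this.softLimit = softLimit;
        this.hardLimit = hardLimit;
    }

    List<Result> filter(List<Result> results, float indexingPressurePercent) {
        List<Result> accepted = new ArrayList<>();
        for (Result r : results) {
            if (indexingPressurePercent <= softLimit) {
                accepted.add(r); // low pressure: keep everything
            } else if (indexingPressurePercent <= hardLimit) {
                // moderate pressure: always keep high-priority, sample the rest
                if (r.highPriority || random.nextFloat() < 1 - indexingPressurePercent) {
                    accepted.add(r);
                }
            } else if (r.highPriority) {
                accepted.add(r); // severe pressure: high-priority only
            }
        }
        return accepted;
    }
}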
Otherwise, index - // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). - long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); - float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; - List results = request.getAnomalyResults(); - - if (results == null || results.size() < 1) { - listener.onResponse(new ADResultBulkResponse()); - } - + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ADResultBulkRequest request) { BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); if (indexingPressurePercent <= softLimit) { - for (ResultWriteRequest resultWriteRequest : results) { - addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getCustomResultIndex()); + for (ResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); } } else if (indexingPressurePercent <= hardLimit) { // exceed soft limit (60%) but smaller than hard limit (90%) float acceptProbability = 1 - indexingPressurePercent; - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority() || random.nextFloat() < acceptProbability) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } else { // if exceeding hard limit, only index non-zero grade or error result - for (ResultWriteRequest resultWriteRequest : results) { + for (ADResultWriteRequest resultWriteRequest : results) { AnomalyResult result = resultWriteRequest.getResult(); if (result.isHighPriority()) { - addResult(bulkRequest, result, resultWriteRequest.getCustomResultIndex()); + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); } } } - if (bulkRequest.numberOfActions() > 0) { - client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { - List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); - listener.onResponse(new ADResultBulkResponse(failedRequests)); - }, e -> { - LOG.error("Failed to bulk index AD result", e); - listener.onFailure(e); - })); - } else { - listener.onResponse(new ADResultBulkResponse()); - } + return bulkRequest; } private void addResult(BulkRequest bulkRequest, AnomalyResult result, String resultIndex) { diff --git a/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java new file mode 100644 index 000000000..916b9edfd --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/ADResultProcessor.java @@ -0,0 +1,485 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ad.transport; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import 
org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ADResultProcessor extends + ResultProcessor { + private static final Logger LOG = LogManager.getLogger(ADResultProcessor.class); + + private final ADModelManager adModelManager; + + public ADResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + Stats timeSeriesStats, + ADTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + ADModelManager adModelManager + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, 
+            realTimeTaskManager,
+            xContentRegistry,
+            client,
+            clientUtil,
+            indexNameExpressionResolver,
+            transportResultResponseClazz,
+            featureManager,
+            AD_MAX_ENTITIES_PER_QUERY,
+            AD_PAGE_SIZE,
+            AnalysisType.AD
+        );
+        this.adModelManager = adModelManager;
+    }
+
+    // For single stream detector
+    @Override
+    protected ActionListener<SinglePointFeatures> onFeatureResponseForSingleStreamConfig(
+        String adID,
+        Config config,
+        ActionListener<AnomalyResultResponse> listener,
+        String rcfModelId,
+        DiscoveryNode rcfNode,
+        long dataStartTime,
+        long dataEndTime
+    ) {
+        return ActionListener.wrap(featureOptional -> {
+            List<FeatureData> featureInResponse = null;
+            AnomalyDetector detector = (AnomalyDetector) config;
+            if (featureOptional.getUnprocessedFeatures().isPresent()) {
+                featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector);
+            }
+
+            if (!featureOptional.getProcessedFeatures().isPresent()) {
+
+                Optional<Exception> exception = coldStartIfNoCheckPoint(detector);
+                if (exception.isPresent()) {
+                    listener.onFailure(exception.get());
+                    return;
+                }
+
+                if (!featureOptional.getUnprocessedFeatures().isPresent()) {
+                    // Feature not available is common when we have data holes. Respond with an
+                    // empty response and don't log to avoid bloating our logs.
+                    LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID);
+                    listener
+                        .onResponse(
+                            ResultResponse
+                                .create(
+                                    new ArrayList<FeatureData>(),
+                                    "No data in current detection window",
+                                    null,
+                                    null,
+                                    false,
+                                    transportResultResponseClazz
+                                )
+                        );
+
+                } else {
+                    LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID);
+                    listener
+                        .onResponse(
+                            ResultResponse
+                                .create(
+                                    featureInResponse,
+                                    "No full shingle in current detection window",
+                                    null,
+                                    null,
+                                    false,
+                                    transportResultResponseClazz
+                                )
+                        );
+                }
+                return;
+            }
+
+            final AtomicReference<Exception> failure = new AtomicReference<Exception>();
+
+            LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId);
+
+            RCFActionListener rcfListener = new RCFActionListener(
+                rcfModelId,
+                failure,
+                rcfNode.getId(),
+                detector,
+                listener,
+                featureInResponse,
+                adID
+            );
+
+            // The threshold for splitting RCF models in single-stream detectors.
+            // The smallest machine in the Amazon managed service has 1GB heap.
+            // With the setting, the desired model size there is 2 MB.
+            // By default, we can have at most 5 features. Since the default shingle size
+            // is 8, we have at most 40 dimensions in RCF. In our current RCF setting,
+            // 30 trees, and bounding box cache ratio 0, 40 dimensions use 449KB.
+            // Users can increase the number of features to 10 and shingle size to 60,
+            // 30 trees, bounding box cache ratio 0, 600 dimensions use 1.8 MB.
+            // Since these sizes are smaller than the threshold 2 MB, we won't split models
+            // even in the smallest machine.
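+            // Illustrative arithmetic for the sizes above (added sketch, assuming
+            // dimensions = number of features * shingle size):
+            //   default:  5 features * shingle size  8 =  40 dimensions
+            //   maximum: 10 features * shingle size 60 = 600 dimensions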
+            transportService
+                .sendRequest(
+                    rcfNode,
+                    RCFResultAction.NAME,
+                    new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()),
+                    option,
+                    new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new)
+                );
+        }, exception -> { handleQueryFailure(exception, listener, adID); });
+    }
+
+    // For single stream detector
+    class RCFActionListener implements ActionListener<RCFResultResponse> {
+        private String modelID;
+        private AtomicReference<Exception> failure;
+        private String rcfNodeID;
+        private Config detector;
+        private ActionListener<AnomalyResultResponse> listener;
+        private List<FeatureData> featureInResponse;
+        private final String adID;
+
+        RCFActionListener(
+            String modelID,
+            AtomicReference<Exception> failure,
+            String rcfNodeID,
+            Config detector,
+            ActionListener<AnomalyResultResponse> listener,
+            List<FeatureData> features,
+            String adID
+        ) {
+            this.modelID = modelID;
+            this.failure = failure;
+            this.rcfNodeID = rcfNodeID;
+            this.detector = detector;
+            this.listener = listener;
+            this.featureInResponse = features;
+            this.adID = adID;
+        }
+
+        @Override
+        public void onResponse(RCFResultResponse response) {
+            try {
+                nodeStateManager.resetBackpressureCounter(rcfNodeID, adID);
+                if (response != null) {
+                    listener
+                        .onResponse(
+                            new AnomalyResultResponse(
+                                response.getAnomalyGrade(),
+                                response.getConfidence(),
+                                response.getRCFScore(),
+                                featureInResponse,
+                                null,
+                                response.getTotalUpdates(),
+                                detector.getIntervalInMinutes(),
+                                false,
+                                response.getRelativeIndex(),
+                                response.getAttribution(),
+                                response.getPastValues(),
+                                response.getExpectedValuesList(),
+                                response.getLikelihoodOfValues(),
+                                response.getThreshold()
+                            )
+                        );
+                } else {
+                    LOG.warn(ResultProcessor.NULL_RESPONSE + " {} for {}", modelID, rcfNodeID);
+                    listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG));
+                }
+            } catch (Exception ex) {
+                LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex);
+                ResultProcessor.handleExecuteException(ex, listener, adID);
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            try {
+                handlePredictionFailure(e, adID, rcfNodeID, failure);
+                Exception exception = coldStartIfNoModel(failure, detector);
+                if (exception != null) {
+                    listener.onFailure(exception);
+                } else {
+                    listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception"));
+                }
+            } catch (Exception ex) {
+                LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex);
+                ResultProcessor.handleExecuteException(ex, listener, adID);
+            }
+        }
+    }
+
+    /**
+     * Verify failure of rcf or threshold models. If there is no model, trigger cold
+     * start. If there is an exception for the previous cold start of this detector,
+     * throw exception to the caller.
+     *
+     * @param failure object that may contain exceptions thrown
+     * @param detector detector object
+     * @return exception if AD job execution gets resource not found exception
+     * @throws Exception when the input failure is not a ResourceNotFoundException.
+     *   List of exceptions we can throw
+     *     1. Exception from cold start:
+     *       1). InternalFailure due to
+     *         a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start
+     *       2). EndRunException with endNow equal to false
+     *         a. training data not available
+     *         b. cold start cannot succeed
+     *         c. invalid training data
+     *       3). EndRunException with endNow equal to true
+     *         a. invalid search query
+     *     2. LimitExceededException from one of the RCF model nodes when the total size of the models
+     *        is more than X% of heap memory.
+     *     3. InternalFailure wrapping OpenSearchTimeoutException inside caused by
+     *        RCF/Threshold model node failing to get checkpoint to restore model before timeout.
+     */
+    private Exception coldStartIfNoModel(AtomicReference<Exception> failure, Config detector) throws Exception {
+        Exception exp = failure.get();
+        if (exp == null) {
+            return null;
+        }
+
+        // return exceptions like LimitExceededException to caller
+        if (!(exp instanceof ResourceNotFoundException)) {
+            return exp;
+        }
+
+        // fetch previous cold start exception
+        String adID = detector.getId();
+        final Optional<Exception> previousException = nodeStateManager.fetchExceptionAndClear(adID);
+        if (previousException.isPresent()) {
+            Exception exception = previousException.get();
+            LOG.error("Previous exception of {}: {}", () -> adID, () -> exception);
+            if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) {
+                return exception;
+            }
+        }
+        LOG.info("Trigger cold start for {}", detector.getId());
+        // only used in single-stream anomaly detector thus type cast
+        coldStart((AnomalyDetector) detector);
+        return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG));
+    }
+
+    // only used for single-stream anomaly detector
+    private void coldStart(AnomalyDetector detector) {
+        String detectorId = detector.getId();
+
+        // If last cold start is not finished, we don't trigger another one
+        if (nodeStateManager.isColdStartRunning(detectorId)) {
+            return;
+        }
+
+        final Releasable coldStartFinishingCallback = nodeStateManager.markColdStartRunning(detectorId);
+
+        ActionListener<Optional<double[][]>> listener = ActionListener.wrap(trainingData -> {
+            if (trainingData.isPresent()) {
+                double[][] dataPoints = trainingData.get();
+
+                ActionListener<Void> trainModelListener = ActionListener
+                    .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> {
+                        if (exception instanceof TimeSeriesException) {
+                            // e.g., partitioned model exceeds memory limit
+                            nodeStateManager.setException(detectorId, exception);
+                        } else if (exception instanceof IllegalArgumentException) {
+                            // IllegalArgumentException due to invalid training data
+                            nodeStateManager
+                                .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false));
+                        } else if (exception instanceof OpenSearchTimeoutException) {
+                            nodeStateManager
+                                .setException(
+                                    detectorId,
+                                    new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception)
+                                );
+                        } else {
+                            nodeStateManager
+                                .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false));
+                        }
+                    });
+
+                adModelManager
+                    .trainModel(
+                        detector,
+                        dataPoints,
+                        new ThreadedActionListener<>(
+                            LOG,
+                            threadPool,
+                            TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME,
+                            trainModelListener,
+                            false
+                        )
+                    );
+            } else {
+                nodeStateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false));
+            }
+        }, exception -> {
+            if (exception instanceof OpenSearchTimeoutException) {
+                nodeStateManager
+                    .setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception));
+            } else if (exception instanceof TimeSeriesException) {
+                // e.g., Invalid search query
+                nodeStateManager.setException(detectorId, exception);
+            } else {
+                nodeStateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false));
+            }
+        });
+
+        final ActionListener<Optional<double[][]>> listenerWithReleaseCallback = ActionListener
+            .runAfter(listener, coldStartFinishingCallback::close);
+
+        threadPool
+            .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)
+            .execute(
+                () -> featureManager
+                    .getColdStartData(
+                        detector,
+                        new ThreadedActionListener<>(
+                            LOG,
+                            threadPool,
+                            TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME,
+                            listenerWithReleaseCallback,
+                            false
+                        )
+                    )
+            );
+    }
+
+    /**
+     * Check if checkpoint for a detector exists or not. If not and previous
+     * run is not EndRunException whose endNow is true, trigger cold start.
+     * @param detector detector object
+     * @return previous cold start exception
+     */
+    private Optional<Exception> coldStartIfNoCheckPoint(AnomalyDetector detector) {
+        String detectorId = detector.getId();
+
+        Optional<Exception> previousException = nodeStateManager.fetchExceptionAndClear(detectorId);
+
+        if (previousException.isPresent()) {
+            Exception exception = previousException.get();
+            LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception);
+            if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) {
+                return previousException;
+            }
+        }
+
+        nodeStateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> {
+            if (!checkpointExists) {
+                LOG.info("Trigger cold start for {}", detectorId);
+                coldStart(detector);
+            }
+        }, exception -> {
+            Throwable cause = ExceptionsHelper.unwrapCause(exception);
+            if (cause instanceof IndexNotFoundException) {
+                LOG.info("Trigger cold start for {}", detectorId);
+                coldStart(detector);
+            } else {
+                String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId);
+                LOG.error(errorMsg, exception);
+                nodeStateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception));
+            }
+        }));
+
+        return previousException;
+    }
+
+    @Override
+    protected void findException(Throwable cause, String adID, AtomicReference<Exception> failure, String nodeId) {
+        if (cause == null) {
+            LOG.error(new ParameterizedMessage("Null input exception"));
+            return;
+        }
+
+        Exception causeException = (Exception) cause;
+
+        if (causeException instanceof IndexNotFoundException && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) {
+            // checkpoint index does not exist
+            // ResourceNotFoundException will trigger cold start later
+            failure.set(new ResourceNotFoundException(adID, causeException.getMessage()));
+        }
+        super.findException(cause, adID, failure, nodeId);
+    }
+}
diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java
index f6f39ab85..987606b28 100644
--- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java
+++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesAction.java
@@ -12,7 +12,7 @@
 package org.opensearch.ad.transport;
 
 import org.opensearch.action.ActionType;
-import org.opensearch.ad.constant.CommonValue;
+import org.opensearch.ad.constant.ADCommonValue;
 
 /**
  * ADStatsNodesAction class
@@ -20,7 +20,7 @@ public class ADStatsNodesAction extends ActionType<ADStatsNodesResponse> {
     // Internal Action which is not used for public facing RestAPIs.
-    public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes";
+    public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "stats/nodes";
     public static final ADStatsNodesAction INSTANCE = new ADStatsNodesAction();
 
     /**
diff --git a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java
index 17a81da0a..970a42f77 100644
--- a/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/ADStatsNodesTransportAction.java
@@ -20,14 +20,14 @@
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.nodes.TransportNodesAction;
-import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.InternalStatNames;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.monitor.jvm.JvmService;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.stats.InternalStatNames;
+import org.opensearch.timeseries.stats.Stats;
 import org.opensearch.transport.TransportService;
 
 /**
@@ -36,7 +36,7 @@ public class ADStatsNodesTransportAction extends
     TransportNodesAction<ADStatsRequest, ADStatsNodesResponse, ADStatsNodeRequest, ADStatsNodeResponse> {
-    private ADStats adStats;
+    private Stats adStats;
     private final JvmService jvmService;
     private final ADTaskManager adTaskManager;
 
@@ -47,7 +47,7 @@
      * @param clusterService ClusterService
      * @param transportService TransportService
      * @param actionFilters Action Filters
-     * @param adStats ADStats object
+     * @param adStats TimeSeriesStats object
     * @param jvmService ES JVM Service
     * @param adTaskManager AD task manager
     */
@@ -57,7 +57,7 @@ public ADStatsNodesTransportAction(
         ClusterService clusterService,
         TransportService transportService,
         ActionFilters actionFilters,
-        ADStats adStats,
+        Stats adStats,
         JvmService jvmService,
         ADTaskManager adTaskManager
     ) {
diff --git a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java b/src/main/java/org/opensearch/ad/transport/ADStatsResponse.java
similarity index 98%
rename from src/main/java/org/opensearch/ad/stats/ADStatsResponse.java
rename to src/main/java/org/opensearch/ad/transport/ADStatsResponse.java
index f90e451f9..41877eb2d 100644
--- a/src/main/java/org/opensearch/ad/stats/ADStatsResponse.java
+++ b/src/main/java/org/opensearch/ad/transport/ADStatsResponse.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
 */
-package org.opensearch.ad.stats;
+package org.opensearch.ad.transport;
 
 import java.io.IOException;
 import java.util.Map;
@@ -18,7 +18,6 @@
 import org.apache.commons.lang.builder.HashCodeBuilder;
 import org.apache.commons.lang.builder.ToStringBuilder;
 import org.opensearch.ad.model.Mergeable;
-import org.opensearch.ad.transport.ADStatsNodesResponse;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.xcontent.ToXContent;
diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java
index f2b198d1c..f66d9e1ec 100644
--- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java
+++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileAction.java
@@ -14,11 +14,11 @@
 import static org.opensearch.ad.constant.ADCommonName.AD_TASK;
 
 import org.opensearch.action.ActionType;
-import org.opensearch.ad.constant.CommonValue;
+import org.opensearch.ad.constant.ADCommonValue;
 
 public class ADTaskProfileAction extends ActionType<ADTaskProfileResponse> {
 
-    public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK;
+    public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/" + AD_TASK;
     public static final ADTaskProfileAction INSTANCE = new ADTaskProfileAction();
 
     private ADTaskProfileAction() {
diff --git a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java
index 6902d6de8..4bfbf7ca3 100644
--- a/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/ADTaskProfileTransportAction.java
@@ -18,13 +18,13 @@
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.nodes.TransportNodesAction;
-import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.model.ADTaskProfile;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.cluster.HashRing;
 import org.opensearch.transport.TransportService;
 
 public class ADTaskProfileTransportAction extends
@@ -79,7 +79,7 @@ protected ADTaskProfileNodeResponse newNodeResponse(StreamInput in) throws IOExc
     @Override
     protected ADTaskProfileNodeResponse nodeOperation(ADTaskProfileNodeRequest request) {
         String remoteNodeId = request.getParentTask().getNodeId();
-        Version remoteAdVersion = hashRing.getAdVersion(remoteNodeId);
+        Version remoteAdVersion = hashRing.getVersion(remoteNodeId);
         ADTaskProfile adTaskProfile = adTaskManager.getLocalADTaskProfilesByDetectorId(request.getId());
         return new ADTaskProfileNodeResponse(clusterService.localNode(), adTaskProfile, remoteAdVersion);
     }
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java
index b11283181..b03180b70 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobAction.java
@@ -12,15 +12,16 @@
 package org.opensearch.ad.transport;
 
 import org.opensearch.action.ActionType;
-import org.opensearch.ad.constant.CommonValue;
+import org.opensearch.ad.constant.ADCommonValue;
+import org.opensearch.timeseries.transport.JobResponse;
 
-public class AnomalyDetectorJobAction extends ActionType<AnomalyDetectorJobResponse> {
+public class AnomalyDetectorJobAction extends ActionType<JobResponse> {
     // External Action which used for public facing RestAPIs.
-    public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement";
+    public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/jobmanagement";
     public static final AnomalyDetectorJobAction INSTANCE = new AnomalyDetectorJobAction();
 
     private AnomalyDetectorJobAction() {
-        super(NAME, AnomalyDetectorJobResponse::new);
+        super(NAME, JobResponse::new);
     }
 }
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java
deleted file mode 100644
index 3a62315a6..000000000
--- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobRequest.java
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- *
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-package org.opensearch.ad.transport;
-
-import java.io.IOException;
-
-import org.opensearch.action.ActionRequest;
-import org.opensearch.action.ActionRequestValidationException;
-import org.opensearch.core.common.io.stream.StreamInput;
-import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.timeseries.model.DateRange;
-
-public class AnomalyDetectorJobRequest extends ActionRequest {
-
-    private String detectorID;
-    private DateRange detectionDateRange;
-    private boolean historical;
-    private long seqNo;
-    private long primaryTerm;
-    private String rawPath;
-
-    public AnomalyDetectorJobRequest(StreamInput in) throws IOException {
-        super(in);
-        detectorID = in.readString();
-        seqNo = in.readLong();
-        primaryTerm = in.readLong();
-        rawPath = in.readString();
-        if (in.readBoolean()) {
-            detectionDateRange = new DateRange(in);
-        }
-        historical = in.readBoolean();
-    }
-
-    public AnomalyDetectorJobRequest(String detectorID, long seqNo, long primaryTerm, String rawPath) {
-        this(detectorID, null, false, seqNo, primaryTerm, rawPath);
-    }
-
-    /**
-     * Constructor function.
-     *
-     * The detectionDateRange and historical boolean can be passed in individually.
-     * The historical flag is for stopping detector, the detectionDateRange is for
-     * starting detector. It's ok if historical is true but detectionDateRange is
-     * null.
-     *
-     * @param detectorID detector identifier
-     * @param detectionDateRange detection date range
-     * @param historical historical analysis or not
-     * @param seqNo seq no
-     * @param primaryTerm primary term
-     * @param rawPath raw request path
-     */
-    public AnomalyDetectorJobRequest(
-        String detectorID,
-        DateRange detectionDateRange,
-        boolean historical,
-        long seqNo,
-        long primaryTerm,
-        String rawPath
-    ) {
-        super();
-        this.detectorID = detectorID;
-        this.detectionDateRange = detectionDateRange;
-        this.historical = historical;
-        this.seqNo = seqNo;
-        this.primaryTerm = primaryTerm;
-        this.rawPath = rawPath;
-    }
-
-    public String getDetectorID() {
-        return detectorID;
-    }
-
-    public DateRange getDetectionDateRange() {
-        return detectionDateRange;
-    }
-
-    public long getSeqNo() {
-        return seqNo;
-    }
-
-    public long getPrimaryTerm() {
-        return primaryTerm;
-    }
-
-    public String getRawPath() {
-        return rawPath;
-    }
-
-    public boolean isHistorical() {
-        return historical;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
-        out.writeString(detectorID);
-        out.writeLong(seqNo);
-        out.writeLong(primaryTerm);
-        out.writeString(rawPath);
-        if (detectionDateRange != null) {
-            out.writeBoolean(true);
-            detectionDateRange.writeTo(out);
-        } else {
-            out.writeBoolean(false);
-        }
-        out.writeBoolean(historical);
-    }
-
-    @Override
-    public ActionRequestValidationException validate() {
-        return null;
-    }
-}
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java
index 1f86cefbb..2926db720 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportAction.java
@@ -13,47 +13,30 @@
 import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_START_DETECTOR;
 import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_STOP_DETECTOR;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT;
-import static org.opensearch.timeseries.util.ParseUtils.getUserContext;
-import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute;
-import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.opensearch.action.ActionListener;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.ad.ExecuteADResultResponseRecorder;
+import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
+import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.rest.handler.ADIndexJobActionHandler;
+import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.inject.Inject;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.common.unit.TimeValue;
-import org.opensearch.common.util.concurrent.ThreadContext;
-import org.opensearch.commons.authuser.User;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
-import org.opensearch.tasks.Task;
-import org.opensearch.timeseries.model.DateRange;
-import org.opensearch.timeseries.util.RestHandlerUtils;
+import org.opensearch.timeseries.transport.BaseJobTransportAction;
 import org.opensearch.transport.TransportService;
 
-public class AnomalyDetectorJobTransportAction extends HandledTransportAction<AnomalyDetectorJobRequest, AnomalyDetectorJobResponse> {
-    private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class);
-
-    private final Client client;
-    private final ClusterService clusterService;
-    private final Settings settings;
-    private final ADIndexManagement anomalyDetectionIndices;
-    private final NamedXContentRegistry xContentRegistry;
-    private volatile Boolean filterByEnabled;
-    private final ADTaskManager adTaskManager;
-    private final TransportService transportService;
-    private final ExecuteADResultResponseRecorder recorder;
-
+public class AnomalyDetectorJobTransportAction extends
+    BaseJobTransportAction {
     @Inject
     public AnomalyDetectorJobTransportAction(
@@ -61,94 +44,23 @@ public AnomalyDetectorJobTransportAction(
         TransportService transportService,
         ActionFilters actionFilters,
         Client client,
         ClusterService clusterService,
         Settings settings,
-        ADIndexManagement anomalyDetectionIndices,
         NamedXContentRegistry xContentRegistry,
-        ADTaskManager adTaskManager,
-        ExecuteADResultResponseRecorder recorder
-    ) {
-        super(AnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new);
-        this.transportService = transportService;
-        this.client = client;
-        this.clusterService = clusterService;
-        this.settings = settings;
-        this.anomalyDetectionIndices = anomalyDetectionIndices;
-        this.xContentRegistry = xContentRegistry;
-        this.adTaskManager = adTaskManager;
-        filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings);
-        clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it);
-        this.recorder = recorder;
-    }
-
-    @Override
-    protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener<AnomalyDetectorJobResponse> actionListener) {
-        String detectorId = request.getDetectorID();
-        DateRange detectionDateRange = request.getDetectionDateRange();
-        boolean historical = request.isHistorical();
-        long seqNo = request.getSeqNo();
-        long primaryTerm = request.getPrimaryTerm();
-        String rawPath = request.getRawPath();
-        TimeValue requestTimeout = REQUEST_TIMEOUT.get(settings);
-        String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? FAIL_TO_START_DETECTOR : FAIL_TO_STOP_DETECTOR;
-        ActionListener<AnomalyDetectorJobResponse> listener = wrapRestActionListener(actionListener, errorMessage);
-
-        // By the time request reaches here, the user permissions are validated by Security plugin.
-        User user = getUserContext(client);
-        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
-            resolveUserAndExecute(
-                user,
-                detectorId,
-                filterByEnabled,
-                listener,
-                (anomalyDetector) -> executeDetector(
-                    listener,
-                    detectorId,
-                    detectionDateRange,
-                    historical,
-                    seqNo,
-                    primaryTerm,
-                    rawPath,
-                    requestTimeout,
-                    user,
-                    context
-                ),
-                client,
-                clusterService,
-                xContentRegistry
-            );
-        } catch (Exception e) {
-            logger.error(e);
-            listener.onFailure(e);
-        }
-    }
-
-    private void executeDetector(
-        ActionListener<AnomalyDetectorJobResponse> listener,
-        String detectorId,
-        DateRange detectionDateRange,
-        boolean historical,
-        long seqNo,
-        long primaryTerm,
-        String rawPath,
-        TimeValue requestTimeout,
-        User user,
-        ThreadContext.StoredContext context
+        ADIndexJobActionHandler adIndexJobActionHandler
     ) {
-        IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler(
+        super(
+            transportService,
+            actionFilters,
             client,
-            anomalyDetectionIndices,
-            detectorId,
-            seqNo,
-            primaryTerm,
-            requestTimeout,
+            clusterService,
+            settings,
             xContentRegistry,
-            transportService,
-            adTaskManager,
-            recorder
+            AD_FILTER_BY_BACKEND_ROLES,
+            AnomalyDetectorJobAction.NAME,
+            AD_REQUEST_TIMEOUT,
+            FAIL_TO_START_DETECTOR,
+            FAIL_TO_STOP_DETECTOR,
+            AnomalyDetector.class,
+            adIndexJobActionHandler
         );
-        if (rawPath.endsWith(RestHandlerUtils.START_JOB)) {
-            adTaskManager.startDetector(detectorId, detectionDateRange, handler, user, transportService, context, listener);
-        } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) {
-            adTaskManager.stopDetector(detectorId, historical, handler, user, transportService, listener);
-        }
     }
 }
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java
index d61bd5822..5f413c5b3 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultAction.java
@@ -12,11 +12,11 @@
 package org.opensearch.ad.transport;
 
 import org.opensearch.action.ActionType;
-import org.opensearch.ad.constant.CommonValue;
+import org.opensearch.ad.constant.ADCommonValue;
 
 public class AnomalyResultAction extends ActionType<AnomalyResultResponse> {
     // External Action which used for public facing RestAPIs.
-    public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/run";
+    public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/run";
     public static final AnomalyResultAction INSTANCE = new AnomalyResultAction();
 
     private AnomalyResultAction() {
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java
index e6f788aeb..397271da0 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultRequest.java
@@ -26,56 +26,24 @@
 import org.opensearch.core.common.io.stream.InputStreamStreamInput;
 import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
-import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.transport.ResultRequest;
 
-public class AnomalyResultRequest extends ActionRequest implements ToXContentObject {
-    private String adID;
-    // time range start and end. Unit: epoch milliseconds
-    private long start;
-    private long end;
-
+public class AnomalyResultRequest extends ResultRequest {
     public AnomalyResultRequest(StreamInput in) throws IOException {
         super(in);
-        adID = in.readString();
-        start = in.readLong();
-        end = in.readLong();
     }
 
     public AnomalyResultRequest(String adID, long start, long end) {
-        super();
-        this.adID = adID;
-        this.start = start;
-        this.end = end;
-    }
-
-    public long getStart() {
-        return start;
-    }
-
-    public long getEnd() {
-        return end;
-    }
-
-    public String getAdID() {
-        return adID;
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        super.writeTo(out);
-        out.writeString(adID);
-        out.writeLong(start);
-        out.writeLong(end);
+        super(adID, start, end);
     }
 
     @Override
     public ActionRequestValidationException validate() {
         ActionRequestValidationException validationException = null;
-        if (Strings.isEmpty(adID)) {
+        if (Strings.isEmpty(configId)) {
             validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException);
         }
         if (start <= 0 || end <= 0 || start > end) {
@@ -90,7 +58,7 @@ public ActionRequestValidationException validate() {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field(ADCommonName.ID_JSON_KEY, adID);
+        builder.field(ADCommonName.ID_JSON_KEY, configId);
         builder.field(CommonName.START_JSON_KEY, start);
         builder.field(CommonName.END_JSON_KEY, end);
         builder.endObject();
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java
index 6f65fdb6d..83998c559 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultResponse.java
@@ -17,6 +17,7 @@
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
 
@@ -27,11 +28,11 @@
 import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
-import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.timeseries.model.FeatureData;
+import org.opensearch.timeseries.transport.ResultResponse;
 
-public class AnomalyResultResponse extends ActionResponse implements ToXContentObject {
+public class AnomalyResultResponse extends ResultResponse<AnomalyResult> {
     public static final String ANOMALY_GRADE_JSON_KEY = "anomalyGrade";
     public static final String CONFIDENCE_JSON_KEY = "confidence";
     public static final String ANOMALY_SCORE_JSON_KEY = "anomalyScore";
@@ -49,18 +50,13 @@
     private Double anomalyGrade;
     private Double confidence;
-    private Double anomalyScore;
-    private String error;
-    private List<FeatureData> features;
-    private Long rcfTotalUpdates;
-    private Long detectorIntervalInMinutes;
-    private Boolean isHCDetector;
     private Integer relativeIndex;
     private double[] relevantAttribution;
     private double[] pastValues;
    private double[][] expectedValuesList;
    private double[] likelihoodOfValues;
    private Double threshold;
+    protected Double anomalyScore;
 
     // used when returning an error/exception or empty result
     public AnomalyResultResponse(
@@ -104,14 +100,10 @@ public AnomalyResultResponse(
         double[] likelihoodOfValues,
         Double threshold
     ) {
+        super(features, error, rcfTotalUpdates, detectorIntervalInMinutes, isHCDetector);
         this.anomalyGrade = anomalyGrade;
         this.confidence = confidence;
         this.anomalyScore = anomalyScore;
-        this.features = features;
-        this.error = error;
-        this.rcfTotalUpdates = rcfTotalUpdates;
-        this.detectorIntervalInMinutes = detectorIntervalInMinutes;
-        this.isHCDetector = isHCDetector;
         this.relativeIndex = relativeIndex;
         this.relevantAttribution = currentTimeAttribution;
         this.pastValues = pastValues;
@@ -134,8 +126,8 @@ public AnomalyResultResponse(StreamInput in) throws IOException {
         // new field added since AD 1.1
         // Only send AnomalyResultRequest to local node, no need to change this part for BWC
         rcfTotalUpdates = in.readOptionalLong();
-        detectorIntervalInMinutes = in.readOptionalLong();
-        isHCDetector = in.readOptionalBoolean();
+        configIntervalInMinutes = in.readOptionalLong();
+        isHC = in.readOptionalBoolean();
 
         this.relativeIndex = in.readOptionalInt();
 
@@ -177,10 +169,6 @@ public double getAnomalyGrade() {
         return anomalyGrade;
     }
 
-    public List<FeatureData> getFeatures() {
-        return features;
-    }
-
     public double getConfidence() {
         return confidence;
     }
@@ -189,22 +177,6 @@ public double getAnomalyScore() {
         return anomalyScore;
     }
 
-    public String getError() {
-        return error;
-    }
-
-    public Long getRcfTotalUpdates() {
-        return rcfTotalUpdates;
-    }
-
-    public Long getIntervalInMinutes() {
-        return detectorIntervalInMinutes;
-    }
-
-    public Boolean isHCDetector() {
-        return isHCDetector;
-    }
-
     public Integer getRelativeIndex() {
         return relativeIndex;
     }
@@ -240,8 +212,8 @@ public void writeTo(StreamOutput out) throws IOException {
         }
         out.writeOptionalString(error);
         out.writeOptionalLong(rcfTotalUpdates);
-        out.writeOptionalLong(detectorIntervalInMinutes);
-        out.writeOptionalBoolean(isHCDetector);
+        out.writeOptionalLong(configIntervalInMinutes);
+        out.writeOptionalBoolean(isHC);
 
         out.writeOptionalInt(relativeIndex);
 
@@ -295,7 +267,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         }
         builder.endArray();
         builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates);
-        builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, detectorIntervalInMinutes);
+        builder.field(DETECTOR_INTERVAL_IN_MINUTES_JSON_KEY, configIntervalInMinutes);
         builder.field(RELATIVE_INDEX_FIELD_JSON_KEY, relativeIndex);
         builder.field(RELEVANT_ATTRIBUTION_FIELD_JSON_KEY, relevantAttribution);
         builder.field(PAST_VALUES_FIELD_JSON_KEY, pastValues);
@@ -325,7 +297,7 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti
      *
      * Convert AnomalyResultResponse to AnomalyResult
      *
-     * @param detectorId Detector Id
+     * @param configId Detector Id
      * @param dataStartInstant data start time
      * @param dataEndInstant data end time
     * @param executionStartInstant execution start time
@@ -335,8 +307,9 @@ public static AnomalyResultResponse fromActionResponse(final ActionResponse acti
     * @param error Error
     * @return converted AnomalyResult
     */
-    public AnomalyResult toAnomalyResult(
-        String detectorId,
+    @Override
+    public List<AnomalyResult> toIndexableResults(
+        String configId,
         Instant dataStartInstant,
         Instant dataEndInstant,
         Instant executionStartInstant,
@@ -347,30 +320,43 @@
     ) {
         // Detector interval in milliseconds
         long detectorIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis();
-        return AnomalyResult
-            .fromRawTRCFResult(
-                detectorId,
-                detectorIntervalMilli,
-                null, // real time results have no task id
-                anomalyScore,
-                anomalyGrade,
-                confidence,
-                features,
-                dataStartInstant,
-                dataEndInstant,
-                executionStartInstant,
-                executionEndInstant,
-                error,
-                Optional.empty(),
-                user,
-                schemaVersion,
-                null, // single-stream real-time has no model id
-                relevantAttribution,
-                relativeIndex,
-                pastValues,
-                expectedValuesList,
-                likelihoodOfValues,
-                threshold
+        return Collections
+            .singletonList(
+                AnomalyResult
+                    .fromRawTRCFResult(
+                        configId,
+                        detectorIntervalMilli,
+                        null, // real time results have no task id
+                        anomalyScore,
+                        anomalyGrade,
+                        confidence,
+                        features,
+                        dataStartInstant,
+                        dataEndInstant,
+                        executionStartInstant,
+                        executionEndInstant,
+                        error,
+                        Optional.empty(),
+                        user,
+                        schemaVersion,
+                        null, // single-stream real-time has no model id
+                        relevantAttribution,
+                        relativeIndex,
+                        pastValues,
+                        expectedValuesList,
+                        likelihoodOfValues,
+                        threshold
+                    )
             );
     }
+
+    @Override
+    public boolean shouldSave() {
+        // skipping writing to the result index if not necessary
+        // For a single-stream analysis, the result is not useful if error is null
+        // and rcf score (e.g., thus anomaly grade/confidence/forecasts) is null.
+        // For a HC analysis, we don't need to save on the detector level.
+        // We return 0 or Double.NaN rcf score if there is no error.
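+        // Added note (an assumption, not from the original comment): anomalyScore is
+        // a Double and is auto-unboxed by the comparison below, so error/empty
+        // responses are expected to carry a non-null score (0 or Double.NaN per the
+        // comment above). Since Double.NaN > 0 is false in Java, a NaN score alone
+        // never forces a save.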
+        return super.shouldSave() || anomalyScore > 0;
+    }
 }
diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java
index d7454bcda..8350e9296 100644
--- a/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java
+++ b/src/main/java/org/opensearch/ad/transport/AnomalyResultTransportAction.java
@@ -11,136 +11,58 @@
 package org.opensearch.ad.transport;
 
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE;
-import static org.opensearch.timeseries.constant.CommonMessages.INVALID_SEARCH_QUERY_MSG;
-
-import java.net.ConnectException;
-import java.util.ArrayList;
 import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.stream.Collectors;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.apache.logging.log4j.message.ParameterizedMessage;
-import org.opensearch.ExceptionsHelper;
-import org.opensearch.OpenSearchTimeoutException;
 import org.opensearch.action.ActionListener;
-import org.opensearch.action.ActionListenerResponseHandler;
 import org.opensearch.action.ActionRequest;
-import org.opensearch.action.search.SearchPhaseExecutionException;
-import org.opensearch.action.search.ShardSearchFailure;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
-import org.opensearch.action.support.IndicesOptions;
-import org.opensearch.action.support.ThreadedActionListener;
-import org.opensearch.action.support.master.AcknowledgedResponse;
-import org.opensearch.ad.NodeStateManager;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
-import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.constant.ADCommonMessages;
-import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.feature.CompositeRetriever;
-import org.opensearch.ad.feature.CompositeRetriever.PageIterator;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.feature.SinglePointFeatures;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.ml.SingleStreamModelIdMapper;
-import org.opensearch.ad.model.AnomalyDetector;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.settings.ADEnabledSetting;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.stats.ADStats;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.util.ExceptionUtil;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
-import org.opensearch.cluster.ClusterState;
-import org.opensearch.cluster.block.ClusterBlockLevel;
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
-import org.opensearch.cluster.node.DiscoveryNode;
-import org.opensearch.cluster.node.DiscoveryNodes;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.inject.Inject;
-import org.opensearch.common.lease.Releasable;
 import org.opensearch.common.settings.Settings;
-import org.opensearch.common.transport.NetworkExceptionHelper;
 import org.opensearch.common.util.concurrent.ThreadContext;
-import
org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.index.IndexNotFoundException; -import org.opensearch.node.NodeClosedException; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.InternalFailure; import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; -import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.model.FeatureData; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.transport.ActionNotFoundTransportException; -import org.opensearch.transport.ConnectTransportException; -import org.opensearch.transport.NodeNotConnectedException; -import org.opensearch.transport.ReceiveTimeoutTransportException; -import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; public class AnomalyResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(AnomalyResultTransportAction.class); - static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; - static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. 
Mute node"; - static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; - static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; - static final String NULL_RESPONSE = "Received null response from"; - - static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; - static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; - - private final TransportService transportService; - private final NodeStateManager stateManager; - private final FeatureManager featureManager; - private final ModelManager modelManager; - private final HashRing hashRing; - private final TransportRequestOptions option; - private final ClusterService clusterService; - private final IndexNameExpressionResolver indexNameExpressionResolver; - private final ADStats adStats; - private final ADCircuitBreakerService adCircuitBreakerService; - private final ThreadPool threadPool; + private ADResultProcessor resultProcessor; private final Client client; - private final SecurityClientUtil clientUtil; - private final ADTaskManager adTaskManager; - + private CircuitBreakerService adCircuitBreakerService; // Cache HC detector id. This is used to count HC failure stats. We can tell a detector // is HC or not by checking if detector id exists in this field or not. Will add // detector id to this field when start to run realtime detection and remove detector // id once realtime detection done. private final Set hcDetectors; - private NamedXContentRegistry xContentRegistry; - private Settings settings; - // within an interval, how many percents are used to process requests. - // 1.0 means we use all of the detection interval to process requests. - // to ensure we don't block next interval, it is better to set it less than 1.0. 
- private final float intervalRatioForRequest; - private int maxEntitiesPerInterval; - private int pageSize; + private final Stats adStats; + private final NodeStateManager nodeStateManager; @Inject public AnomalyResultTransportAction( @@ -149,47 +71,45 @@ public AnomalyResultTransportAction( Settings settings, Client client, SecurityClientUtil clientUtil, - NodeStateManager manager, + NodeStateManager nodeStateManager, FeatureManager featureManager, - ModelManager modelManager, + ADModelManager modelManager, HashRing hashRing, ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver, - ADCircuitBreakerService adCircuitBreakerService, - ADStats adStats, + CircuitBreakerService adCircuitBreakerService, + Stats adStats, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager + ADTaskManager realTimeTaskManager ) { super(AnomalyResultAction.NAME, transportService, actionFilters, AnomalyResultRequest::new); - this.transportService = transportService; - this.settings = settings; + this.resultProcessor = new ADResultProcessor( + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityADResultAction.NAME, + StatNames.AD_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + adStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + AnomalyResultResponse.class, + featureManager, + modelManager + ); this.client = client; - this.clientUtil = clientUtil; - this.stateManager = manager; - this.featureManager = featureManager; - this.modelManager = modelManager; - this.hashRing = hashRing; - this.option = TransportRequestOptions - .builder() - .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) - .build(); - this.clusterService = clusterService; - this.indexNameExpressionResolver = indexNameExpressionResolver; this.adCircuitBreakerService = adCircuitBreakerService; - this.adStats = adStats; - this.threadPool = threadPool; this.hcDetectors = new HashSet<>(); - this.xContentRegistry = xContentRegistry; - this.intervalRatioForRequest = AnomalyDetectorSettings.INTERVAL_RATIO_FOR_REQUESTS; - - this.maxEntitiesPerInterval = MAX_ENTITIES_PER_QUERY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_PER_QUERY, it -> maxEntitiesPerInterval = it); - - this.pageSize = PAGE_SIZE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(PAGE_SIZE, it -> pageSize = it); - this.adTaskManager = adTaskManager; + this.adStats = adStats; + this.nodeStateManager = nodeStateManager; } /** @@ -246,7 +166,7 @@ public AnomalyResultTransportAction( protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { AnomalyResultRequest request = AnomalyResultRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + String adID = request.getConfigId(); ActionListener original = listener; listener = ActionListener.wrap(r -> { hcDetectors.remove(adID); @@ -275,864 +195,13 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< return; } try { - stateManager.getAnomalyDetector(adID, onGetDetector(listener, adID, request)); + nodeStateManager.getConfig(adID, AnalysisType.AD, resultProcessor.onGetConfig(listener, 
adID, request, hcDetectors)); } catch (Exception ex) { - handleExecuteException(ex, listener, adID); + ResultProcessor.handleExecuteException(ex, listener, adID); } } catch (Exception e) { LOG.error(e); listener.onFailure(e); } } - - /** - * didn't use ActionListener.wrap so that I can - * 1) use this to refer to the listener inside the listener - * 2) pass parameters using constructors - * - */ - class PageListener implements ActionListener { - private PageIterator pageIterator; - private String detectorId; - private long dataStartTime; - private long dataEndTime; - - PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { - this.pageIterator = pageIterator; - this.detectorId = detectorId; - this.dataStartTime = dataStartTime; - this.dataEndTime = dataEndTime; - } - - @Override - public void onResponse(CompositeRetriever.Page entityFeatures) { - if (pageIterator.hasNext()) { - pageIterator.next(this); - } - if (entityFeatures != null && false == entityFeatures.isEmpty()) { - // wrap expensive operation inside ad threadpool - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> { - try { - - Set>> node2Entities = entityFeatures - .getResults() - .entrySet() - .stream() - .filter(e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).isPresent()) - .collect( - Collectors - .groupingBy( - // from entity name to its node - e -> hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(e.getKey().toString()).get(), - Collectors.toMap(Entry::getKey, Entry::getValue) - ) - ) - .entrySet(); - - Iterator>> iterator = node2Entities.iterator(); - - while (iterator.hasNext()) { - Entry> entry = iterator.next(); - DiscoveryNode modelNode = entry.getKey(); - if (modelNode == null) { - iterator.remove(); - continue; - } - String modelNodeId = modelNode.getId(); - if (stateManager.isMuted(modelNodeId, detectorId)) { - LOG - .info( - String - .format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", modelNodeId, detectorId) - ); - iterator.remove(); - } - } - - final AtomicReference failure = new AtomicReference<>(); - node2Entities.stream().forEach(nodeEntity -> { - DiscoveryNode node = nodeEntity.getKey(); - transportService - .sendRequest( - node, - EntityResultAction.NAME, - new EntityResultRequest(detectorId, nodeEntity.getValue(), dataStartTime, dataEndTime), - option, - new ActionListenerResponseHandler<>( - new EntityResultListener(node.getId(), detectorId, failure), - AcknowledgedResponse::new, - ThreadPool.Names.SAME - ) - ); - }); - - } catch (Exception e) { - LOG.error("Unexpected exception", e); - handleException(e); - } - }); - } - } - - @Override - public void onFailure(Exception e) { - LOG.error("Unexpetected exception", e); - handleException(e); - } - - private void handleException(Exception e) { - Exception convertedException = convertedQueryFailureException(e, detectorId); - if (false == (convertedException instanceof TimeSeriesException)) { - Throwable cause = ExceptionsHelper.unwrapCause(convertedException); - convertedException = new InternalFailure(detectorId, cause); - } - stateManager.setException(detectorId, convertedException); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String adID, - AnomalyResultRequest request - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(adID, "AnomalyDetector is not available.", true)); - return; - } - - 
AnomalyDetector anomalyDetector = detectorOptional.get(); - if (anomalyDetector.isHighCardinality()) { - hcDetectors.add(adID); - adStats.getStat(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName()).increment(); - } - - long delayMillis = Optional - .ofNullable((IntervalTimeConfiguration) anomalyDetector.getWindowDelay()) - .map(t -> t.toDuration().toMillis()) - .orElse(0L); - long dataStartTime = request.getStart() - delayMillis; - long dataEndTime = request.getEnd() - delayMillis; - - adTaskManager - .initRealtimeTaskCacheAndCleanupStaleCache( - adID, - anomalyDetector, - transportService, - ActionListener - .runAfter( - initRealtimeTaskCacheListener(adID), - () -> executeAnomalyDetection(listener, adID, request, anomalyDetector, dataStartTime, dataEndTime) - ) - ); - }, exception -> handleExecuteException(exception, listener, adID)); - } - - private ActionListener initRealtimeTaskCacheListener(String detectorId) { - return ActionListener.wrap(r -> { - if (r) { - LOG.debug("Realtime task cache initied for detector {}", detectorId); - } - }, e -> LOG.error("Failed to init realtime task cache for " + detectorId, e)); - } - - private void executeAnomalyDetection( - ActionListener listener, - String adID, - AnomalyResultRequest request, - AnomalyDetector anomalyDetector, - long dataStartTime, - long dataEndTime - ) { - // HC logic starts here - if (anomalyDetector.isHighCardinality()) { - Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of [{}]", adID), exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - } - - // assume request are in epoch milliseconds - long nextDetectionStartTime = request.getEnd() + (long) (anomalyDetector.getIntervalInMilliseconds() * intervalRatioForRequest); - - CompositeRetriever compositeRetriever = new CompositeRetriever( - dataStartTime, - dataEndTime, - anomalyDetector, - xContentRegistry, - client, - clientUtil, - nextDetectionStartTime, - settings, - maxEntitiesPerInterval, - pageSize, - indexNameExpressionResolver, - clusterService - ); - - PageIterator pageIterator = null; - - try { - pageIterator = compositeRetriever.iterator(); - } catch (Exception e) { - listener.onFailure(new EndRunException(anomalyDetector.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); - return; - } - - PageListener getEntityFeatureslistener = new PageListener(pageIterator, adID, dataStartTime, dataEndTime); - if (pageIterator.hasNext()) { - pageIterator.next(getEntityFeatureslistener); - } - - // We don't know when the pagination will not finish. To not - // block the following interval request to start, we return immediately. - // Pagination will stop itself when the time is up. - if (previousException.isPresent()) { - listener.onFailure(previousException.get()); - } else { - listener - .onResponse( - new AnomalyResultResponse(new ArrayList(), null, null, anomalyDetector.getIntervalInMinutes(), true) - ); - } - return; - } - - // HC logic ends and single entity logic starts here - // We are going to use only 1 model partition for a single stream detector. - // That's why we use 0 here. 
-        String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(adID, 0);
-        Optional<DiscoveryNode> asRCFNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID);
-        if (!asRCFNode.isPresent()) {
-            listener.onFailure(new InternalFailure(adID, "RCF model node is not available."));
-            return;
-        }
-
-        DiscoveryNode rcfNode = asRCFNode.get();
-
-        // we have already returned listener inside shouldStart method
-        if (!shouldStart(listener, adID, anomalyDetector, rcfNode.getId(), rcfModelID)) {
-            return;
-        }
-
-        featureManager
-            .getCurrentFeatures(
-                anomalyDetector,
-                dataStartTime,
-                dataEndTime,
-                onFeatureResponseForSingleEntityDetector(adID, anomalyDetector, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime)
-            );
-    }
-
-    // For single entity detector
-    private ActionListener<SinglePointFeatures> onFeatureResponseForSingleEntityDetector(
-        String adID,
-        AnomalyDetector detector,
-        ActionListener<AnomalyResultResponse> listener,
-        String rcfModelId,
-        DiscoveryNode rcfNode,
-        long dataStartTime,
-        long dataEndTime
-    ) {
-        return ActionListener.wrap(featureOptional -> {
-            List<FeatureData> featureInResponse = null;
-            if (featureOptional.getUnprocessedFeatures().isPresent()) {
-                featureInResponse = ParseUtils.getFeatureData(featureOptional.getUnprocessedFeatures().get(), detector);
-            }
-
-            if (!featureOptional.getProcessedFeatures().isPresent()) {
-                Optional<Exception> exception = coldStartIfNoCheckPoint(detector);
-                if (exception.isPresent()) {
-                    listener.onFailure(exception.get());
-                    return;
-                }
-
-                if (!featureOptional.getUnprocessedFeatures().isPresent()) {
-                    // Feature not available is common when we have data holes. Return an empty response
-                    // and don't log, to avoid bloating our logs.
-                    LOG.debug("No data in current detection window between {} and {} for {}", dataStartTime, dataEndTime, adID);
-                    listener
-                        .onResponse(
-                            new AnomalyResultResponse(
-                                new ArrayList<FeatureData>(),
-                                "No data in current detection window",
-                                null,
-                                null,
-                                false
-                            )
-                        );
-                } else {
-                    LOG.debug("Return at least current feature value between {} and {} for {}", dataStartTime, dataEndTime, adID);
-                    listener
-                        .onResponse(
-                            new AnomalyResultResponse(featureInResponse, "No full shingle in current detection window", null, null, false)
-                        );
-                }
-                return;
-            }
-
-            final AtomicReference<Exception> failure = new AtomicReference<>();
-
-            LOG.info("Sending RCF request to {} for model {}", rcfNode.getId(), rcfModelId);
-
-            RCFActionListener rcfListener = new RCFActionListener(
-                rcfModelId,
-                failure,
-                rcfNode.getId(),
-                detector,
-                listener,
-                featureInResponse,
-                adID
-            );
-
-            transportService
-                .sendRequest(
-                    rcfNode,
-                    RCFResultAction.NAME,
-                    new RCFResultRequest(adID, rcfModelId, featureOptional.getProcessedFeatures().get()),
-                    option,
-                    new ActionListenerResponseHandler<>(rcfListener, RCFResultResponse::new)
-                );
-        }, exception -> { handleQueryFailure(exception, listener, adID); });
-    }
-
-    private void handleQueryFailure(Exception exception, ActionListener<AnomalyResultResponse> listener, String adID) {
-        Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID);
-
-        if (convertedQueryFailureException instanceof EndRunException) {
-            // invalid feature query
-            listener.onFailure(convertedQueryFailureException);
-        } else {
-            handleExecuteException(convertedQueryFailureException, listener, adID);
-        }
-    }
-
-    /**
-     * Convert a query-related exception to EndRunException
-     *
-     * These query exceptions can happen during the starting phase of the OpenSearch
-     * process. Thus, set the stopNow parameter of these EndRunExceptions to false
-     * and confirm the EndRunException is not a false positive.
- * - * @param exception Exception - * @param adID detector Id - * @return the converted exception if the exception is query related - */ - private Exception convertedQueryFailureException(Exception exception, String adID) { - if (ExceptionUtil.isIndexNotAvailable(exception)) { - return new EndRunException(adID, TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false).countedInStats(false); - } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { - // This is to catch invalid aggregation on wrong field type. For example, - // sum aggregation on text field. We should end detector run for such case. - return new EndRunException( - adID, - INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), - exception, - false - ).countedInStats(false); - } - - return exception; - } - - /** - * Verify failure of rcf or threshold models. If there is no model, trigger cold - * start. If there is an exception for the previous cold start of this detector, - * throw exception to the caller. - * - * @param failure object that may contain exceptions thrown - * @param detector detector object - * @return exception if AD job execution gets resource not found exception - * @throws Exception when the input failure is not a ResourceNotFoundException. - * List of exceptions we can throw - * 1. Exception from cold start: - * 1). InternalFailure due to - * a. OpenSearchTimeoutException thrown by putModelCheckpoint during cold start - * 2). EndRunException with endNow equal to false - * a. training data not available - * b. cold start cannot succeed - * c. invalid training data - * 3) EndRunException with endNow equal to true - * a. invalid search query - * 2. LimitExceededException from one of RCF model node when the total size of the models - * is more than X% of heap memory. - * 3. InternalFailure wrapping OpenSearchTimeoutException inside caused by - * RCF/Threshold model node failing to get checkpoint to restore model before timeout. - */ - private Exception coldStartIfNoModel(AtomicReference failure, AnomalyDetector detector) throws Exception { - Exception exp = failure.get(); - if (exp == null) { - return null; - } - - // return exceptions like LimitExceededException to caller - if (!(exp instanceof ResourceNotFoundException)) { - return exp; - } - - // fetch previous cold start exception - String adID = detector.getId(); - final Optional previousException = stateManager.fetchExceptionAndClear(adID); - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", () -> adID, () -> exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return exception; - } - } - LOG.info("Trigger cold start for {}", detector.getId()); - coldStart(detector); - return previousException.orElse(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - - private void findException(Throwable cause, String adID, AtomicReference failure, String nodeId) { - if (cause == null) { - LOG.error(new ParameterizedMessage("Null input exception")); - return; - } - if (cause instanceof Error) { - // we cannot do anything with Error. 
- LOG.error(new ParameterizedMessage("Error during prediction for {}: ", adID), cause); - return; - } - - Exception causeException = (Exception) cause; - - if (causeException instanceof TimeSeriesException) { - failure.set(causeException); - } else if (causeException instanceof NotSerializableExceptionWrapper) { - // we only expect this happens on AD exceptions - Optional actualException = NotSerializedExceptionName - .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, adID); - if (actualException.isPresent()) { - TimeSeriesException adException = actualException.get(); - failure.set(adException); - if (adException instanceof ResourceNotFoundException) { - // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by old node using RCF - // 1.0 - // cannot recognize new checkpoint produced by the coordinating node using compact RCF. Add pressure to mute the node - // after consecutive failures. - stateManager.addPressure(nodeId, adID); - } - } else { - // some unexpected bugs occur while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } else if (causeException instanceof IndexNotFoundException - && causeException.getMessage().contains(ADCommonName.CHECKPOINT_INDEX_NAME)) { - // checkpoint index does not exist - // ResourceNotFoundException will trigger cold start later - failure.set(new ResourceNotFoundException(adID, causeException.getMessage())); - } else if (causeException instanceof OpenSearchTimeoutException) { - // we can have OpenSearchTimeoutException when a node tries to load RCF or - // threshold model - failure.set(new InternalFailure(adID, causeException)); - } else if (causeException instanceof IllegalArgumentException) { - // we can have IllegalArgumentException when a model is corrupted - failure.set(new InternalFailure(adID, causeException)); - } else { - // some unexpected bug occurred or cluster is unstable (e.g., ClusterBlockException) or index is red (e.g. - // NoShardAvailableActionException) while predicting anomaly - failure.set(new EndRunException(adID, CommonMessages.BUG_RESPONSE, causeException, false)); - } - } - - void handleExecuteException(Exception ex, ActionListener listener, String adID) { - if (ex instanceof ClientException) { - listener.onFailure(ex); - } else if (ex instanceof TimeSeriesException) { - listener.onFailure(new InternalFailure((TimeSeriesException) ex)); - } else { - Throwable cause = ExceptionsHelper.unwrapCause(ex); - listener.onFailure(new InternalFailure(adID, cause)); - } - } - - private boolean invalidQuery(SearchPhaseExecutionException ex) { - // If all shards return bad request and failure cause is IllegalArgumentException, we - // consider the feature query is invalid and will not count the error in failure stats. 
- for (ShardSearchFailure failure : ex.shardFailures()) { - if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { - return false; - } - } - return true; - } - - // For single entity detector - class RCFActionListener implements ActionListener { - private String modelID; - private AtomicReference failure; - private String rcfNodeID; - private AnomalyDetector detector; - private ActionListener listener; - private List featureInResponse; - private final String adID; - - RCFActionListener( - String modelID, - AtomicReference failure, - String rcfNodeID, - AnomalyDetector detector, - ActionListener listener, - List features, - String adID - ) { - this.modelID = modelID; - this.failure = failure; - this.rcfNodeID = rcfNodeID; - this.detector = detector; - this.listener = listener; - this.featureInResponse = features; - this.adID = adID; - } - - @Override - public void onResponse(RCFResultResponse response) { - try { - stateManager.resetBackpressureCounter(rcfNodeID, adID); - if (response != null) { - listener - .onResponse( - new AnomalyResultResponse( - response.getAnomalyGrade(), - response.getConfidence(), - response.getRCFScore(), - featureInResponse, - null, - response.getTotalUpdates(), - detector.getIntervalInMinutes(), - false, - response.getRelativeIndex(), - response.getAttribution(), - response.getPastValues(), - response.getExpectedValuesList(), - response.getLikelihoodOfValues(), - response.getThreshold() - ) - ); - } else { - LOG.warn(NULL_RESPONSE + " {} for {}", modelID, rcfNodeID); - listener.onFailure(new InternalFailure(adID, ADCommonMessages.NO_MODEL_ERR_MSG)); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - - @Override - public void onFailure(Exception e) { - try { - handlePredictionFailure(e, adID, rcfNodeID, failure); - Exception exception = coldStartIfNoModel(failure, detector); - if (exception != null) { - listener.onFailure(exception); - } else { - listener.onFailure(new InternalFailure(adID, "Node connection problem or unexpected exception")); - } - } catch (Exception ex) { - LOG.error(new ParameterizedMessage("Unexpected exception for [{}]", adID), ex); - handleExecuteException(ex, listener, adID); - } - } - } - - /** - * Handle a prediction failure. Possibly (i.e., we don't always need to do that) - * convert the exception to a form that AD can recognize and handle and sets the - * input failure reference to the converted exception. - * - * @param e prediction exception - * @param adID Detector Id - * @param nodeID Node Id - * @param failure Parameter to receive the possibly converted function for the - * caller to deal with - */ - private void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { - LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); - if (e == null) { - return; - } - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (hasConnectionIssue(cause)) { - handleConnectionException(nodeID, adID); - } else { - findException(cause, adID, failure, nodeID); - } - } - - /** - * Check if the input exception indicates connection issues. - * During blue-green deployment, we may see ActionNotFoundTransportException. - * Count that as connection issue and isolate that node if it continues to happen. 
- * - * @param e exception - * @return true if we get disconnected from the node or the node is not in the - * right state (being closed) or transport request times out (sent from TimeoutHandler.run) - */ - private boolean hasConnectionIssue(Throwable e) { - return e instanceof ConnectTransportException - || e instanceof NodeClosedException - || e instanceof ReceiveTimeoutTransportException - || e instanceof NodeNotConnectedException - || e instanceof ConnectException - || NetworkExceptionHelper.isCloseConnectionException(e) - || e instanceof ActionNotFoundTransportException; - } - - private void handleConnectionException(String node, String detectorId) { - final DiscoveryNodes nodes = clusterService.state().nodes(); - if (!nodes.nodeExists(node)) { - hashRing.buildCirclesForRealtimeAD(); - return; - } - // rebuilding is not done or node is unresponsive - stateManager.addPressure(node, detectorId); - } - - /** - * Since we need to read from customer index and write to anomaly result index, - * we need to make sure we can read and write. - * - * @param state Cluster state - * @return whether we have global block or not - */ - private boolean checkGlobalBlock(ClusterState state) { - return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null - || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null; - } - - /** - * Similar to checkGlobalBlock, we check block on the indices level. - * - * @param state Cluster state - * @param level block level - * @param indices the indices on which to check block - * @return whether any of the index has block on the level. - */ - private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) { - // the original index might be an index expression with wildcards like "log*", - // so we need to expand the expression to concrete index name - String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices); - - return state.blocks().indicesBlockedException(level, concreteIndices) != null; - } - - /** - * Check if we should start anomaly prediction. - * - * @param listener listener to respond back to AnomalyResultRequest. - * @param adID detector ID - * @param detector detector instance corresponds to adID - * @param rcfNodeId the rcf model hosting node ID for adID - * @param rcfModelID the rcf model ID for adID - * @return if we can start anomaly prediction. 
- */ - private boolean shouldStart( - ActionListener listener, - String adID, - AnomalyDetector detector, - String rcfNodeId, - String rcfModelID - ) { - ClusterState state = clusterService.state(); - if (checkGlobalBlock(state)) { - listener.onFailure(new InternalFailure(adID, READ_WRITE_BLOCKED)); - return false; - } - - if (stateManager.isMuted(rcfNodeId, adID)) { - listener - .onFailure( - new InternalFailure( - adID, - String.format(Locale.ROOT, NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID) - ) - ); - return false; - } - - if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) { - listener.onFailure(new InternalFailure(adID, INDEX_READ_BLOCKED)); - return false; - } - - return true; - } - - private void coldStart(AnomalyDetector detector) { - String detectorId = detector.getId(); - - // If last cold start is not finished, we don't trigger another one - if (stateManager.isColdStartRunning(detectorId)) { - return; - } - - final Releasable coldStartFinishingCallback = stateManager.markColdStartRunning(detectorId); - - ActionListener> listener = ActionListener.wrap(trainingData -> { - if (trainingData.isPresent()) { - double[][] dataPoints = trainingData.get(); - - ActionListener trainModelListener = ActionListener - .wrap(res -> { LOG.info("Succeeded in training {}", detectorId); }, exception -> { - if (exception instanceof TimeSeriesException) { - // e.g., partitioned model exceeds memory limit - stateManager.setException(detectorId, exception); - } else if (exception instanceof IllegalArgumentException) { - // IllegalArgumentException due to invalid training data - stateManager - .setException(detectorId, new EndRunException(detectorId, "Invalid training data", exception, false)); - } else if (exception instanceof OpenSearchTimeoutException) { - stateManager - .setException( - detectorId, - new InternalFailure(detectorId, "Time out while indexing cold start checkpoint", exception) - ); - } else { - stateManager - .setException(detectorId, new EndRunException(detectorId, "Error while training model", exception, false)); - } - }); - - modelManager - .trainModel( - detector, - dataPoints, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - trainModelListener, - false - ) - ); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Cannot get training data", false)); - } - }, exception -> { - if (exception instanceof OpenSearchTimeoutException) { - stateManager.setException(detectorId, new InternalFailure(detectorId, "Time out while getting training data", exception)); - } else if (exception instanceof TimeSeriesException) { - // e.g., Invalid search query - stateManager.setException(detectorId, exception); - } else { - stateManager.setException(detectorId, new EndRunException(detectorId, "Error while cold start", exception, false)); - } - }); - - final ActionListener> listenerWithReleaseCallback = ActionListener - .runAfter(listener, coldStartFinishingCallback::close); - - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute( - () -> featureManager - .getColdStartData( - detector, - new ThreadedActionListener<>( - LOG, - threadPool, - TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME, - listenerWithReleaseCallback, - false - ) - ) - ); - } - - /** - * Check if checkpoint for an detector exists or not. If not and previous - * run is not EndRunException whose endNow is true, trigger cold start. 
- * @param detector detector object - * @return previous cold start exception - */ - private Optional coldStartIfNoCheckPoint(AnomalyDetector detector) { - String detectorId = detector.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error(new ParameterizedMessage("Previous exception of {}:", detectorId), exception); - if (exception instanceof EndRunException && ((EndRunException) exception).isEndNow()) { - return previousException; - } - } - - stateManager.getDetectorCheckpoint(detectorId, ActionListener.wrap(checkpointExists -> { - if (!checkpointExists) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } - }, exception -> { - Throwable cause = ExceptionsHelper.unwrapCause(exception); - if (cause instanceof IndexNotFoundException) { - LOG.info("Trigger cold start for {}", detectorId); - coldStart(detector); - } else { - String errorMsg = String.format(Locale.ROOT, "Fail to get checkpoint state for %s", detectorId); - LOG.error(errorMsg, exception); - stateManager.setException(detectorId, new TimeSeriesException(errorMsg, exception)); - } - })); - - return previousException; - } - - class EntityResultListener implements ActionListener { - private String nodeId; - private final String adID; - private AtomicReference failure; - - EntityResultListener(String nodeId, String adID, AtomicReference failure) { - this.nodeId = nodeId; - this.adID = adID; - this.failure = failure; - } - - @Override - public void onResponse(AcknowledgedResponse response) { - try { - if (response.isAcknowledged() == false) { - LOG.error("Cannot send entities' features to {} for {}", nodeId, adID); - stateManager.addPressure(nodeId, adID); - } else { - stateManager.resetBackpressureCounter(nodeId, adID); - } - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - @Override - public void onFailure(Exception e) { - try { - // e.g., we have connection issues with all of the nodes while restarting clusters - LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, adID), e); - - handleException(e); - - } catch (Exception ex) { - LOG.error("Unexpected exception: {} for {}", ex, adID); - handleException(ex); - } - } - - private void handleException(Exception e) { - handlePredictionFailure(e, adID, nodeId, failure); - if (failure.get() != null) { - stateManager.setException(adID, failure.get()); - } - } - } } diff --git a/src/main/java/org/opensearch/ad/transport/CronAction.java b/src/main/java/org/opensearch/ad/transport/CronAction.java index 1e64a0f45..0c31fe940 100644 --- a/src/main/java/org/opensearch/ad/transport/CronAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class CronAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
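An aside on the cold-start guard used in the removed coldStart method above: it skips the run when isColdStartRunning reports a cold start in flight, otherwise it takes a Releasable from markColdStartRunning and closes it via ActionListener.runAfter once the whole async chain finishes, success or failure. The idiom, reduced to a self-contained sketch with a hypothetical class that stands in for the plugin's NodeStateManager:

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

class ColdStartGuardSketch {
    private final Set<String> running = ConcurrentHashMap.newKeySet();

    // Returns a release handle if we won the race, or null if a cold start
    // for this detector is already in flight and the caller should back off.
    AutoCloseable markRunning(String detectorId) {
        if (!running.add(detectorId)) {
            return null;
        }
        return () -> running.remove(detectorId);
    }

    public static void main(String[] args) throws Exception {
        ColdStartGuardSketch guard = new ColdStartGuardSketch();
        AutoCloseable release = guard.markRunning("detector-1");
        System.out.println(guard.markRunning("detector-1") == null); // true: already running
        release.close(); // run after the training listener chain finishes
        System.out.println(guard.markRunning("detector-1") != null); // true: can start again
    }
}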
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "cron"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "cron"; public static final CronAction INSTANCE = new CronAction(); private CronAction() { diff --git a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java index edc21cd6f..8b7a467f1 100644 --- a/src/main/java/org/opensearch/ad/transport/CronTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/CronTransportAction.java @@ -20,26 +20,32 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.task.ForecastTaskManager; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.transport.TransportService; public class CronTransportAction extends TransportNodesAction { private final Logger LOG = LogManager.getLogger(CronTransportAction.class); private NodeStateManager transportStateManager; - private ModelManager modelManager; + private ADModelManager adModelManager; private FeatureManager featureManager; - private CacheProvider cacheProvider; - private EntityColdStarter entityColdStarter; + private ADCacheProvider adCacheProvider; + private ForecastCacheProvider forecastCacheProvider; + private ADEntityColdStart adEntityColdStarter; + private ForecastColdStart forecastColdStarter; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; @Inject public CronTransportAction( @@ -48,11 +54,14 @@ public CronTransportAction( TransportService transportService, ActionFilters actionFilters, NodeStateManager tarnsportStatemanager, - ModelManager modelManager, + ADModelManager adModelManager, FeatureManager featureManager, - CacheProvider cacheProvider, - EntityColdStarter entityColdStarter, - ADTaskManager adTaskManager + ADCacheProvider adCacheProvider, + ForecastCacheProvider forecastCacheProvider, + ADEntityColdStart adEntityColdStarter, + ForecastColdStart forecastColdStarter, + ADTaskManager adTaskManager, + ForecastTaskManager forecastTaskManager ) { super( CronAction.NAME, @@ -66,11 +75,14 @@ public CronTransportAction( CronNodeResponse.class ); this.transportStateManager = tarnsportStatemanager; - this.modelManager = modelManager; + this.adModelManager = adModelManager; this.featureManager = featureManager; - this.cacheProvider = cacheProvider; - this.entityColdStarter = entityColdStarter; + this.adCacheProvider = adCacheProvider; + this.forecastCacheProvider = forecastCacheProvider; + this.adEntityColdStarter = adEntityColdStarter; + this.forecastColdStarter = 
forecastColdStarter; this.adTaskManager = adTaskManager; + this.forecastTaskManager = forecastTaskManager; } @Override @@ -97,22 +109,22 @@ protected CronNodeResponse newNodeResponse(StreamInput in) throws IOException { */ @Override protected CronNodeResponse nodeOperation(CronNodeRequest request) { - LOG.info("Start running AD hourly cron."); + LOG.info("Start running hourly cron."); + // ====================== + // AD + // ====================== // makes checkpoints for hosted models and stop hosting models not actively // used. // for single-entity detector - modelManager - .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining model", e))); + adModelManager + .maintenance(ActionListener.wrap(v -> LOG.debug("model maintenance done"), e -> LOG.error("Error maintaining ad model", e))); // for multi-entity detector - cacheProvider.get().maintenance(); + adCacheProvider.get().maintenance(); // delete unused buffered shingle data featureManager.maintenance(); - // delete unused transport state - transportStateManager.maintenance(); - - entityColdStarter.maintenance(); + adEntityColdStarter.maintenance(); // clean child tasks and AD results of deleted detector level task adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); @@ -125,6 +137,20 @@ protected CronNodeResponse nodeOperation(CronNodeRequest request) { // maintain running realtime tasks: clean stale running realtime task cache adTaskManager.maintainRunningRealtimeTasks(); + // ====================== + // Forecast + // ====================== + forecastCacheProvider.get().maintenance(); + forecastColdStarter.maintenance(); + // clean child tasks and forecast results of deleted forecaster level task + forecastTaskManager.cleanChildTasksAndResultsOfDeletedTask(); + + // ====================== + // Common + // ====================== + // delete unused transport state + transportStateManager.maintenance(); + return new CronNodeResponse(clusterService.localNode()); } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java similarity index 55% rename from src/main/java/org/opensearch/ad/transport/DeleteModelAction.java rename to src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java index 3af6982b0..c4eeef176 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelAction.java @@ -12,14 +12,15 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.DeleteModelResponse; -public class DeleteModelAction extends ActionType { +public class DeleteADModelAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; - public static final DeleteModelAction INSTANCE = new DeleteModelAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteADModelAction INSTANCE = new DeleteADModelAction(); - private DeleteModelAction() { + private DeleteADModelAction() { super(NAME, DeleteModelResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java new file mode 100644 index 000000000..3be7b649c --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/DeleteADModelTransportAction.java @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ad.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.task.ADTaskCacheManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class DeleteADModelTransportAction extends + BaseDeleteModelTransportAction { + private static final Logger LOG = LogManager.getLogger(DeleteADModelTransportAction.class); + private ADModelManager modelManager; + private FeatureManager featureManager; + + @Inject + public DeleteADModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ADModelManager modelManager, + FeatureManager featureManager, + ADCacheProvider cache, + ADTaskCacheManager adTaskCacheManager, + ADEntityColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + cache, + adTaskCacheManager, + coldStarter, + DeleteADModelAction.NAME + ); + this.modelManager = modelManager; + this.featureManager = featureManager; + } + + /** + * + * Delete checkpoint document (including both RCF and thresholding model), in-memory models, + * buffered shingle data, transport state, and anomaly result + * + * @param request delete request + * @return delete response including local node Id. 
+ */ + @Override + protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { + super.nodeOperation(request); + String adID = request.getConfigID(); + + // delete in-memory models and model checkpoint + modelManager + .clear( + adID, + ActionListener + .wrap( + r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), + e -> LOG.error("Fail to delete model for " + adID, e) + ) + ); + + // delete buffered shingle data + featureManager.clear(adID); + + LOG.info("Finished deleting ad models for {}", adID); + return new DeleteModelNodeResponse(clusterService.localNode()); + } + +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java index 75dc34638..70d655507 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.delete.DeleteResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class DeleteAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/delete"; public static final DeleteAnomalyDetectorAction INSTANCE = new DeleteAnomalyDetectorAction(); private DeleteAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java index ebc7577f0..f7db2fe8f 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTransportAction.java @@ -13,9 +13,8 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_DETECTOR; import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -34,7 +33,7 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.client.Client; @@ -48,8 +47,12 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexNotFoundException; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Job; 
+import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -62,6 +65,7 @@ public class DeleteAnomalyDetectorTransportAction extends HandledTransportAction private NamedXContentRegistry xContentRegistry; private final ADTaskManager adTaskManager; private volatile Boolean filterByEnabled; + private final NodeStateManager nodeStateManager; @Inject public DeleteAnomalyDetectorTransportAction( @@ -71,6 +75,7 @@ public DeleteAnomalyDetectorTransportAction( ClusterService clusterService, Settings settings, NamedXContentRegistry xContentRegistry, + NodeStateManager nodeStateManager, ADTaskManager adTaskManager ) { super(DeleteAnomalyDetectorAction.NAME, transportService, actionFilters, DeleteAnomalyDetectorRequest::new); @@ -79,15 +84,16 @@ public DeleteAnomalyDetectorTransportAction( this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.adTaskManager = adTaskManager; - filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.nodeStateManager = nodeStateManager; + filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); } @Override protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, ActionListener actionListener) { String detectorId = request.getDetectorID(); LOG.info("Delete anomaly detector job {}", detectorId); - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_DELETE_DETECTOR); // By the time request reaches here, the user permissions are validated by Security plugin. try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -96,7 +102,7 @@ protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, Action detectorId, filterByEnabled, listener, - (anomalyDetector) -> adTaskManager.getDetector(detectorId, detector -> { + (anomalyDetector) -> nodeStateManager.getConfig(detectorId, AnalysisType.AD, detector -> { if (!detector.isPresent()) { // In a mixed cluster, if delete detector request routes to node running AD1.0, then it will // not delete detector tasks. User can re-delete these deleted detector after cluster upgraded, @@ -108,7 +114,7 @@ protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, Action // Check if there is realtime job or historical analysis task running. If none of these running, we // can delete the detector. 
getDetectorJob(detectorId, listener, () -> { - adTaskManager.getAndExecuteOnLatestDetectorLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { + adTaskManager.getAndExecuteOnLatestConfigLevelTask(detectorId, HISTORICAL_DETECTOR_TASK_TYPES, adTask -> { if (adTask.isPresent() && !adTask.get().isDone()) { listener.onFailure(new OpenSearchStatusException("Detector is running", RestStatus.INTERNAL_SERVER_ERROR)); } else { @@ -119,7 +125,9 @@ protected void doExecute(Task task, DeleteAnomalyDetectorRequest request, Action }, listener), client, clusterService, - xContentRegistry + xContentRegistry, + DeleteResponse.class, + AnomalyDetector.class ); } catch (Exception e) { LOG.error(e); @@ -214,7 +222,7 @@ private void onGetAdJobResponseForWrite(GetResponse response, ActionListener { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "results/delete"; public static final DeleteAnomalyResultsAction INSTANCE = new DeleteAnomalyResultsAction(); private DeleteAnomalyResultsAction() { diff --git a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java index 69b12ab0c..e96cc68a1 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/DeleteAnomalyResultsTransportAction.java @@ -12,9 +12,7 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_DELETE_AD_RESULT; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import org.apache.logging.log4j.LogManager; @@ -32,6 +30,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.transport.TransportService; public class DeleteAnomalyResultsTransportAction extends HandledTransportAction { @@ -50,8 +49,8 @@ public DeleteAnomalyResultsTransportAction( ) { super(DeleteAnomalyResultsAction.NAME, transportService, actionFilters, DeleteByQueryRequest::new); this.client = client; - filterEnabled = FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); + filterEnabled = AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); } @Override @@ -61,7 +60,7 @@ protected void doExecute(Task task, DeleteByQueryRequest request, ActionListener } public void delete(DeleteByQueryRequest request, ActionListener listener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { validateRole(request, user, listener); } catch (Exception e) { @@ -79,7 +78,7 @@ private void 
validateRole(DeleteByQueryRequest request, User user, ActionListene } else { // Security is enabled and backend role filter is enabled try { - addUserBackendRolesFilter(user, request.getSearchRequest().source()); + ParseUtils.addUserBackendRolesFilter(user, request.getSearchRequest().source()); client.execute(DeleteByQueryAction.INSTANCE, request, listener); } catch (Exception e) { listener.onFailure(e); diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java similarity index 62% rename from src/main/java/org/opensearch/ad/transport/EntityResultAction.java rename to src/main/java/org/opensearch/ad/transport/EntityADResultAction.java index c519858b4..f17c23416 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultAction.java @@ -13,14 +13,14 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; -public class EntityResultAction extends ActionType { +public class EntityADResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; - public static final EntityResultAction INSTANCE = new EntityResultAction(); + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityADResultAction INSTANCE = new EntityADResultAction(); - private EntityResultAction() { + private EntityADResultAction() { super(NAME, AcknowledgedResponse::new); } diff --git a/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java new file mode 100644 index 000000000..cc5750aca --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/EntityADResultTransportAction.java @@ -0,0 +1,160 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.ad.transport;
+
+import java.util.Optional;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.support.ActionFilters;
+import org.opensearch.action.support.HandledTransportAction;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.ad.caching.ADCacheProvider;
+import org.opensearch.ad.caching.ADPriorityCache;
+import org.opensearch.ad.indices.ADIndex;
+import org.opensearch.ad.indices.ADIndexManagement;
+import org.opensearch.ad.ml.ADCheckpointDao;
+import org.opensearch.ad.ml.ADEntityColdStart;
+import org.opensearch.ad.ml.ADModelManager;
+import org.opensearch.ad.ml.ThresholdingResult;
+import org.opensearch.ad.model.AnomalyResult;
+import org.opensearch.ad.ratelimit.ADCheckpointReadWorker;
+import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker;
+import org.opensearch.ad.ratelimit.ADColdEntityWorker;
+import org.opensearch.ad.ratelimit.ADColdStartWorker;
+import org.opensearch.ad.ratelimit.ADResultWriteRequest;
+import org.opensearch.ad.ratelimit.ADResultWriteWorker;
+import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler;
+import org.opensearch.common.inject.Inject;
+import org.opensearch.tasks.Task;
+import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.AnalysisType;
+import org.opensearch.timeseries.NodeStateManager;
+import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.caching.CacheProvider;
+import org.opensearch.timeseries.common.exception.EndRunException;
+import org.opensearch.timeseries.common.exception.LimitExceededException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.stats.Stats;
+import org.opensearch.timeseries.transport.EntityResultProcessor;
+import org.opensearch.timeseries.transport.EntityResultRequest;
+import org.opensearch.timeseries.util.ExceptionUtil;
+import org.opensearch.transport.TransportService;
+
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
+
+/**
+ * Entry-point for HCAD workflow. We have created multiple queues for
+ * coordinating the workflow. The overall workflow is: 1. We store as many
+ * frequently used entity models in a cache as allowed by the memory limit (10%
+ * heap). If an entity feature is a hit, we use the in-memory model to detect
+ * anomalies and record results using the result write queue. 2. If an entity
+ * feature is a miss, we check if there is free memory or any other entity's
+ * model can be evacuated. An in-memory entity's frequency may be lower compared
+ * to the cache miss entity. If that's the case, we replace the lower frequency
+ * entity's model with the higher frequency entity's model. To load the higher
+ * frequency entity's model, we first check if a model exists on disk by sending
+ * a checkpoint read queue request. If there is a checkpoint, we load it to
+ * memory, perform detection, and save the result using the result write queue.
+ * Otherwise, we enqueue a cold start request to the cold start queue for model
+ * training. If training is successful, we save the learned model via the
+ * checkpoint write queue. 3. We also have the cold entity queue configured for
+ * cold entities, and the model training and inference are connected by serial
+ * juxtaposition to limit resource usage.
+ */ +public class EntityADResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityADResultTransportAction.class); + private CircuitBreakerService adCircuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + @Inject + public EntityADResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ADModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ADCacheProvider entityCache, + NodeStateManager stateManager, + ADIndexManagement indexUtil, + ADResultWriteWorker resultWriteQueue, + ADCheckpointReadWorker checkpointReadQueue, + ADColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ADColdStartWorker entityColdStartWorker, + Stats timeSeriesStats + ) { + super(EntityADResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.adCircuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + this.intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + ADIndex.RESULT, + indexUtil, + resultWriteQueue, + ADResultWriteRequest.class, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue + ); + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (adCircuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String detectorId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(detectorId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", detectorId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, detectorId); + } + + stateManager + .getConfig( + detectorId, + AnalysisType.AD, + intervalDataProcessor.onGetConfig(listener, detectorId, request, previousException) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's anomaly grade", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java b/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java index c699d9a03..d140785ea 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class EntityProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
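An aside illustrating the three-way dispatch the EntityADResultTransportAction javadoc above describes: cache hit scores in memory, a miss with an on-disk checkpoint goes to the checkpoint read queue, and a miss without one goes to cold start. Names here are illustrative only; the real flow runs through EntityResultProcessor and the rate-limiting workers.

import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;

class HcDispatchSketch {
    final Map<String, double[]> cache = new HashMap<>();          // in-memory models (~10% heap)
    final Queue<String> checkpointReadQueue = new ArrayDeque<>(); // restore model from disk
    final Queue<String> coldStartQueue = new ArrayDeque<>();      // train a fresh model

    void handle(String modelId, double[] features, boolean checkpointExists) {
        double[] model = cache.get(modelId);
        if (model != null) {
            score(model, features);           // hit: detect in memory, write result
        } else if (checkpointExists) {
            checkpointReadQueue.add(modelId); // miss with checkpoint: load, then score
        } else {
            coldStartQueue.add(modelId);      // miss without checkpoint: cold start
        }
    }

    void score(double[] model, double[] features) {
        // in the plugin, results go to the result write queue
    }
}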
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile/entity"; public static final EntityProfileAction INSTANCE = new EntityProfileAction(); private EntityProfileAction() { diff --git a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java index 0a124360d..f649891a6 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/EntityProfileTransportAction.java @@ -21,11 +21,8 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; @@ -35,13 +32,18 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * Transport action to get entity profile. 
*/ @@ -56,7 +58,7 @@ public class EntityProfileTransportAction extends HandledTransportAction cacheProvider; @Inject public EntityProfileTransportAction( @@ -65,7 +67,7 @@ public EntityProfileTransportAction( Settings settings, HashRing hashRing, ClusterService clusterService, - CacheProvider cacheProvider + CacheProvider cacheProvider ) { super(EntityProfileAction.NAME, transportService, actionFilters, EntityProfileRequest::new); this.transportService = transportService; @@ -73,7 +75,7 @@ public EntityProfileTransportAction( this.option = TransportRequestOptions .builder() .withType(TransportRequestOptions.Type.REG) - .withTimeout(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings)) + .withTimeout(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings)) .build(); this.clusterService = clusterService; this.cacheProvider = cacheProvider; @@ -91,7 +93,7 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener } // we use entity's toString (e.g., app_0) to find its node // This should be consistent with how we land a model node in AnomalyResultTransportAction - Optional node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(entityValue.toString()); + Optional node = hashRing.getOwningNodeWithSameLocalVersionForRealtime(entityValue.toString()); if (false == node.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; @@ -100,12 +102,12 @@ protected void doExecute(Task task, EntityProfileRequest request, ActionListener String modelId = modelIdOptional.get(); DiscoveryNode localNode = clusterService.localNode(); if (localNode.getId().equals(nodeId)) { - EntityCache cache = cacheProvider.get(); + ADPriorityCache cache = cacheProvider.get(); Set profilesToCollect = request.getProfilesToCollect(); EntityProfileResponse.Builder builder = new EntityProfileResponse.Builder(); if (profilesToCollect.contains(EntityProfileName.ENTITY_INFO)) { builder.setActive(cache.isActive(adID, modelId)); - builder.setLastActiveMs(cache.getLastActiveMs(adID, modelId)); + builder.setLastActiveMs(cache.getLastActiveTime(adID, modelId)); } if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS) || profilesToCollect.contains(EntityProfileName.STATE)) { builder.setTotalUpdates(cache.getTotalUpdates(adID, modelId)); diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java deleted file mode 100644 index fd48b302b..000000000 --- a/src/main/java/org/opensearch/ad/transport/EntityResultTransportAction.java +++ /dev/null @@ -1,354 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -package org.opensearch.ad.transport; - -import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.action.ActionListener; -import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.EntityFeatureRequest; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.ratelimit.ResultWriteRequest; -import org.opensearch.ad.ratelimit.ResultWriteWorker; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.util.ExceptionUtil; -import org.opensearch.common.inject.Inject; -import org.opensearch.tasks.Task; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; -import org.opensearch.timeseries.common.exception.EndRunException; -import org.opensearch.timeseries.common.exception.LimitExceededException; -import org.opensearch.timeseries.constant.CommonMessages; -import org.opensearch.timeseries.model.Entity; -import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ParseUtils; -import org.opensearch.transport.TransportService; - -/** - * Entry-point for HCAD workflow. We have created multiple queues for coordinating - * the workflow. The overrall workflow is: - * 1. We store as many frequently used entity models in a cache as allowed by the - * memory limit (10% heap). If an entity feature is a hit, we use the in-memory model - * to detect anomalies and record results using the result write queue. - * 2. If an entity feature is a miss, we check if there is free memory or any other - * entity's model can be evacuated. An in-memory entity's frequency may be lower - * compared to the cache miss entity. If that's the case, we replace the lower - * frequency entity's model with the higher frequency entity's model. To load the - * higher frequency entity's model, we first check if a model exists on disk by - * sending a checkpoint read queue request. If there is a checkpoint, we load it - * to memory, perform detection, and save the result using the result write queue. - * Otherwise, we enqueue a cold start request to the cold start queue for model - * training. If training is successful, we save the learned model via the checkpoint - * write queue. - * 3. 
We also have the cold entity queue configured for cold entities, and the model - * training and inference are connected by serial juxtaposition to limit resource usage. - */ -public class EntityResultTransportAction extends HandledTransportAction { - - private static final Logger LOG = LogManager.getLogger(EntityResultTransportAction.class); - private ModelManager modelManager; - private ADCircuitBreakerService adCircuitBreakerService; - private CacheProvider cache; - private final NodeStateManager stateManager; - private ADIndexManagement indexUtil; - private ResultWriteWorker resultWriteQueue; - private CheckpointReadWorker checkpointReadQueue; - private ColdEntityWorker coldEntityQueue; - private ThreadPool threadPool; - private EntityColdStartWorker entityColdStartWorker; - private ADStats adStats; - - @Inject - public EntityResultTransportAction( - ActionFilters actionFilters, - TransportService transportService, - ModelManager manager, - ADCircuitBreakerService adCircuitBreakerService, - CacheProvider entityCache, - NodeStateManager stateManager, - ADIndexManagement indexUtil, - ResultWriteWorker resultWriteQueue, - CheckpointReadWorker checkpointReadQueue, - ColdEntityWorker coldEntityQueue, - ThreadPool threadPool, - EntityColdStartWorker entityColdStartWorker, - ADStats adStats - ) { - super(EntityResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); - this.modelManager = manager; - this.adCircuitBreakerService = adCircuitBreakerService; - this.cache = entityCache; - this.stateManager = stateManager; - this.indexUtil = indexUtil; - this.resultWriteQueue = resultWriteQueue; - this.checkpointReadQueue = checkpointReadQueue; - this.coldEntityQueue = coldEntityQueue; - this.threadPool = threadPool; - this.entityColdStartWorker = entityColdStartWorker; - this.adStats = adStats; - } - - @Override - protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { - if (adCircuitBreakerService.isOpen()) { - threadPool - .executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME) - .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); - listener.onFailure(new LimitExceededException(request.getId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); - return; - } - - try { - String detectorId = request.getId(); - - Optional previousException = stateManager.fetchExceptionAndClear(detectorId); - - if (previousException.isPresent()) { - Exception exception = previousException.get(); - LOG.error("Previous exception of {}: {}", detectorId, exception); - if (exception instanceof EndRunException) { - EndRunException endRunException = (EndRunException) exception; - if (endRunException.isEndNow()) { - listener.onFailure(exception); - return; - } - } - - listener = ExceptionUtil.wrapListener(listener, exception, detectorId); - } - - stateManager.getAnomalyDetector(detectorId, onGetDetector(listener, detectorId, request, previousException)); - } catch (Exception exception) { - LOG.error("fail to get entity's anomaly grade", exception); - listener.onFailure(exception); - } - } - - private ActionListener> onGetDetector( - ActionListener listener, - String detectorId, - EntityResultRequest request, - Optional prevException - ) { - return ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - listener.onFailure(new EndRunException(detectorId, "AnomalyDetector is not available.", false)); - return; - } - - AnomalyDetector detector = detectorOptional.get(); - - if (request.getEntities() == null) { - 
listener.onFailure(new EndRunException(detectorId, "Fail to get any entities from request.", false)); - return; - } - - Instant executionStartTime = Instant.now(); - Map cacheMissEntities = new HashMap<>(); - for (Entry entityEntry : request.getEntities().entrySet()) { - Entity categoricalValues = entityEntry.getKey(); - - if (isEntityFromOldNodeMsg(categoricalValues) - && detector.getCategoryFields() != null - && detector.getCategoryFields().size() == 1) { - Map attrValues = categoricalValues.getAttributes(); - // handle a request from a version before OpenSearch 1.1. - categoricalValues = Entity - .createSingleAttributeEntity(detector.getCategoryFields().get(0), attrValues.get(ADCommonName.EMPTY_FIELD)); - } - - Optional modelIdOptional = categoricalValues.getModelId(detectorId); - if (false == modelIdOptional.isPresent()) { - continue; - } - - String modelId = modelIdOptional.get(); - double[] datapoint = entityEntry.getValue(); - ModelState entityModel = cache.get().get(modelId, detector); - if (entityModel == null) { - // cache miss - cacheMissEntities.put(categoricalValues, datapoint); - continue; - } - try { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(datapoint, entityModel, modelId, categoricalValues, detector.getShingleSize()); - // result.getRcfScore() = 0 means the model is not initialized - // result.getGrade() = 0 means it is not an anomaly - // So many OpenSearchRejectedExecutionException if we write no matter what - if (result.getRcfScore() > 0) { - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(request.getStart()), - Instant.ofEpochMilli(request.getEnd()), - executionStartTime, - Instant.now(), - ParseUtils.getFeatureData(datapoint, detector), - Optional.ofNullable(categoricalValues), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM, - r, - detector.getCustomResultIndex() - ) - ); - } - } - } catch (IllegalArgumentException e) { - // fail to score likely due to model corruption. Re-cold start to recover. 
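The javadoc of the class deleted above is the only prose description of the HCAD queue workflow, so a condensed illustration may help reviewers of this removal. The sketch below uses hypothetical stand-in types (a plain map and queues, not the plugin's real CacheProvider or worker classes); it only mirrors the hit/miss split and the write-priority rule the comment describes:

    import java.util.ArrayDeque;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.Queue;

    // Hypothetical sketch of the HCAD dispatch flow described in the javadoc above.
    class HcadDispatchSketch {
        enum Priority { HIGH, MEDIUM, LOW }

        private final Map<String, double[]> modelCache = new HashMap<>();     // capped at ~10% heap in the real plugin
        private final Queue<String> resultWriteQueue = new ArrayDeque<>();    // step 1: record results
        private final Queue<String> checkpointReadQueue = new ArrayDeque<>(); // step 2: hot cache misses
        private final Queue<String> coldEntityQueue = new ArrayDeque<>();     // step 3: cold cache misses

        void dispatch(String modelId, double[] features, boolean hot) {
            double[] model = modelCache.get(modelId);
            if (model != null) {
                double grade = score(model, features);
                // anomalous results are written ahead of normal ones
                resultWriteQueue.add(modelId + ":" + (grade > 0 ? Priority.HIGH : Priority.MEDIUM));
            } else if (hot) {
                // try restoring a checkpoint from disk before cold starting
                checkpointReadQueue.add(modelId);
            } else {
                // cold entities are trained lazily at the lowest priority
                coldEntityQueue.add(modelId);
            }
        }

        private double score(double[] model, double[] features) {
            return 0; // placeholder; the real code consults the RCF model here
        }
    }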
- LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - cache.get().removeEntityModel(detectorId, modelId); - entityColdStartWorker - .put( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - RequestPriority.MEDIUM, - categoricalValues, - datapoint, - request.getStart() - ) - ); - } - } - - // split hot and cold entities - Pair, List> hotColdEntities = cache - .get() - .selectUpdateCandidate(cacheMissEntities.keySet(), detectorId, detector); - - List hotEntityRequests = new ArrayList<>(); - List coldEntityRequests = new ArrayList<>(); - - for (Entity hotEntity : hotColdEntities.getLeft()) { - double[] hotEntityValue = cacheMissEntities.get(hotEntity); - if (hotEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); - continue; - } - hotEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // hot entities has MEDIUM priority - RequestPriority.MEDIUM, - hotEntity, - hotEntityValue, - request.getStart() - ) - ); - } - - for (Entity coldEntity : hotColdEntities.getRight()) { - double[] coldEntityValue = cacheMissEntities.get(coldEntity); - if (coldEntityValue == null) { - LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); - continue; - } - coldEntityRequests - .add( - new EntityFeatureRequest( - System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, - // cold entities has LOW priority - RequestPriority.LOW, - coldEntity, - coldEntityValue, - request.getStart() - ) - ); - } - - checkpointReadQueue.putAll(hotEntityRequests); - coldEntityQueue.putAll(coldEntityRequests); - - // respond back - if (prevException.isPresent()) { - listener.onFailure(prevException.get()); - } else { - listener.onResponse(new AcknowledgedResponse(true)); - } - }, exception -> { - LOG - .error( - new ParameterizedMessage( - "fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", - detectorId, - request.getStart(), - request.getEnd() - ), - exception - ); - listener.onFailure(exception); - }); - } - - /** - * Whether the received entity comes from an node that doesn't support multi-category fields. - * This can happen during rolling-upgrade or blue/green deployment. - * - * Specifically, when receiving an EntityResultRequest from an incompatible node, - * EntityResultRequest(StreamInput in) gets an String that represents an entity. - * But Entity class requires both an category field name and value. Since we - * don't have access to detector config in EntityResultRequest(StreamInput in), - * we put CommonName.EMPTY_FIELD as the placeholder. In this method, - * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity - * comes from an incompatible node. If it is, we will add the field name back - * as EntityResultTranportAction has access to the detector config object. - * - * @param categoricalValues deserialized Entity from inbound message. - * @return Whether the received entity comes from an node that doesn't support multi-category fields. 
- */ - private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { - Map attrValues = categoricalValues.getAttributes(); - return (attrValues != null && attrValues.containsKey(ADCommonName.EMPTY_FIELD)); - } -} diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java index 309714cc8..43c62eed3 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskAction.java @@ -14,14 +14,15 @@ import static org.opensearch.ad.constant.ADCommonName.AD_TASK; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.JobResponse; -public class ForwardADTaskAction extends ActionType { +public class ForwardADTaskAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/" + AD_TASK + "/forward"; public static final ForwardADTaskAction INSTANCE = new ForwardADTaskAction(); private ForwardADTaskAction() { - super(NAME, AnomalyDetectorJobResponse::new); + super(NAME, JobResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java index adc8e36a8..08c144ea3 100644 --- a/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ForwardADTaskTransportAction.java @@ -11,10 +11,6 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.model.ADTask.ERROR_FIELD; -import static org.opensearch.ad.model.ADTask.STATE_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD; - import java.util.Arrays; import java.util.List; @@ -24,28 +20,32 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.common.inject.Inject; import org.opensearch.commons.authuser.User; import org.opensearch.core.rest.RestStatus; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.feature.FeatureManager; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; -public class ForwardADTaskTransportAction extends HandledTransportAction { +public class ForwardADTaskTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(ForwardADTaskTransportAction.class); private final TransportService 
transportService; private final ADTaskManager adTaskManager; private final ADTaskCacheManager adTaskCacheManager; + private final ADIndexJobActionHandler indexJobHandler; // ========================================================= // Fields below contains cache for realtime AD on coordinating @@ -64,7 +64,8 @@ public ForwardADTaskTransportAction( ADTaskManager adTaskManager, ADTaskCacheManager adTaskCacheManager, FeatureManager featureManager, - NodeStateManager stateManager + NodeStateManager stateManager, + ADIndexJobActionHandler indexJobHandler ) { super(ForwardADTaskAction.NAME, transportService, actionFilters, ForwardADTaskRequest::new); this.adTaskManager = adTaskManager; @@ -72,10 +73,11 @@ public ForwardADTaskTransportAction( this.adTaskCacheManager = adTaskCacheManager; this.featureManager = featureManager; this.stateManager = stateManager; + this.indexJobHandler = indexJobHandler; } @Override - protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener listener) { + protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener listener) { ADTaskAction adTaskAction = request.getAdTaskAction(); AnomalyDetector detector = request.getDetector(); DateRange detectionDateRange = request.getDetectionDateRange(); @@ -107,7 +109,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener case START: // Start historical analysis for detector logger.debug("Received START action for detector {}", detectorId); - adTaskManager.startDetector(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { + indexJobHandler.startConfig(detector, detectionDateRange, user, transportService, ActionListener.wrap(r -> { adTaskCacheManager.setDetectorTaskSlots(detector.getId(), availableTaskSlots); listener.onResponse(r); }, e -> listener.onFailure(e))); @@ -120,8 +122,8 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (!adTaskCacheManager.hasEntity(detectorId)) { adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); logger.info("Historical HC detector done, will remove from cache, detector id:{}", detectorId); - listener.onResponse(new AnomalyDetectorJobResponse(detectorId, 0, 0, 0, RestStatus.OK)); - ADTaskState state = !adTask.isEntityTask() && adTask.getError() != null ? ADTaskState.FAILED : ADTaskState.FINISHED; + listener.onResponse(new JobResponse(detectorId)); + TaskState state = !adTask.isEntityTask() && adTask.getError() != null ? TaskState.FAILED : TaskState.FINISHED; adTaskManager.setHCDetectorTaskDone(adTask, state, listener); } else { logger.debug("Run next entity for detector " + detectorId); @@ -132,11 +134,11 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTask.getParentTaskId(), ImmutableMap .of( - STATE_FIELD, - ADTaskState.RUNNING.name(), - TASK_PROGRESS_FIELD, + TimeSeriesTask.STATE_FIELD, + TaskState.RUNNING.name(), + TimeSeriesTask.TASK_PROGRESS_FIELD, adTaskManager.hcDetectorProgress(detectorId), - ERROR_FIELD, + TimeSeriesTask.ERROR_FIELD, adTask.getError() != null ? adTask.getError() : "" ) ); @@ -157,18 +159,18 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener if (adTask.isEntityTask()) { // AD task must be entity level task.
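The PUSH_BACK_ENTITY handling that continues below encodes a retry policy that is easy to miss in the diff noise: a retryable entity failure re-queues the entity at the tail of the pending queue until a retry limit is hit, after which the entity is dropped. A minimal sketch of that policy (hypothetical names and limit; the real bookkeeping lives in ADTaskCacheManager):

    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical sketch of the push-back-with-retry-limit policy for failed entity tasks.
    class EntityRetrySketch {
        static final int MAX_RETRIES = 3; // assumed value, not the plugin's setting

        private final Deque<String> pendingEntities = new ArrayDeque<>();
        private final Map<String, Integer> retryCounts = new HashMap<>();

        void onEntityFailure(String entity, boolean retryableError) {
            int tried = retryCounts.merge(entity, 1, Integer::sum);
            if (retryableError && tried <= MAX_RETRIES) {
                // retryable: park the entity at the tail so other entities run first
                pendingEntities.addLast(entity);
            } else {
                // non-retryable or over the limit: drop the entity and move on
                retryCounts.remove(entity);
            }
        }
    }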
adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (adTaskManager.isRetryableError(adTask.getError()) - && !adTaskCacheManager.exceedRetryLimit(adTask.getId(), adTask.getTaskId())) { + && !adTaskCacheManager.exceedRetryLimit(adTask.getConfigId(), adTask.getTaskId())) { // If retryable exception happens when run entity task, will push back entity to the end // of pending entities queue, then we can retry it later. - adTaskCacheManager.pushBackEntity(adTask.getTaskId(), adTask.getId(), entityValue); + adTaskCacheManager.pushBackEntity(adTask.getTaskId(), adTask.getConfigId(), entityValue); } else { // If exception is not retryable or exceeds retry limit, will remove this entity. - adTaskCacheManager.removeEntity(adTask.getId(), entityValue); + adTaskCacheManager.removeEntity(adTask.getConfigId(), entityValue); logger.warn("Entity task failed, task id: {}, entity: {}", adTask.getTaskId(), adTask.getEntity().toString()); } if (!adTaskCacheManager.hasEntity(detectorId)) { adTaskCacheManager.setDetectorTaskSlots(detectorId, 0); - adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.FINISHED, listener); + adTaskManager.setHCDetectorTaskDone(adTask, TaskState.FINISHED, listener); } else { logger.debug("scale task slots for PUSH_BACK_ENTITY, detector {} task {}", detectorId, adTask.getTaskId()); int taskSlots = adTaskCacheManager.scaleDownHCDetectorTaskSlots(detectorId, 1); @@ -176,7 +178,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener logger.debug("After scale down, only 1 task slot reserved for detector {}, run next entity", detectorId); adTaskManager.runNextEntityForHCADHistorical(adTask, transportService, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.ACCEPTED)); + listener.onResponse(new JobResponse(adTask.getTaskId())); } } else { logger.warn("Can only push back entity task"); @@ -193,7 +195,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.scaleUpDetectorTaskSlots(detectorId, newSlots); } } - listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(detector.getId())); break; case CANCEL: logger.debug("Received CANCEL action for detector {}", detectorId); @@ -204,9 +206,9 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener adTaskCacheManager.clearPendingEntities(detectorId); adTaskCacheManager.removeRunningEntity(detectorId, entityValue); if (!adTaskCacheManager.hasEntity(detectorId) || !adTask.isEntityTask()) { - adTaskManager.setHCDetectorTaskDone(adTask, ADTaskState.STOPPED, listener); + adTaskManager.setHCDetectorTaskDone(adTask, TaskState.STOPPED, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(adTask.getTaskId())); } else { listener.onFailure(new IllegalArgumentException("Only support cancel HC now")); } @@ -227,7 +229,7 @@ protected void doExecute(Task task, ForwardADTaskRequest request, ActionListener for (String entity : staleRunningEntities) { adTaskManager.removeStaleRunningEntity(adTask, entity, transportService, listener); } - listener.onResponse(new AnomalyDetectorJobResponse(adTask.getTaskId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(adTask.getTaskId())); break; case CLEAN_CACHE: boolean historicalTask = adTask.isHistoricalTask(); @@ -249,7 +251,7 @@ protected void doExecute(Task 
task, ForwardADTaskRequest request, ActionListener stateManager.clear(detectorId); featureManager.clear(detectorId); } - listener.onResponse(new AnomalyDetectorJobResponse(detector.getId(), 0, 0, 0, RestStatus.OK)); + listener.onResponse(new JobResponse(detector.getId())); break; default: listener.onFailure(new OpenSearchStatusException("Unsupported AD task action " + adTaskAction, RestStatus.BAD_REQUEST)); diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java index c4232047d..c740ed24e 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class GetAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detectors/get"; public static final GetAnomalyDetectorAction INSTANCE = new GetAnomalyDetectorAction(); private GetAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java index e1532b816..652076531 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorResponse.java @@ -16,7 +16,6 @@ import org.opensearch.action.ActionResponse; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.EntityProfile; import org.opensearch.core.common.io.stream.StreamInput; @@ -24,6 +23,7 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.RestHandlerUtils; public class GetAnomalyDetectorResponse extends ActionResponse implements ToXContentObject { @@ -34,7 +34,7 @@ public class GetAnomalyDetectorResponse extends ActionResponse implements ToXCon private long primaryTerm; private long seqNo; private AnomalyDetector detector; - private AnomalyDetectorJob adJob; + private Job adJob; private ADTask realtimeAdTask; private ADTask historicalAdTask; private RestStatus restStatus; @@ -65,7 +65,7 @@ public GetAnomalyDetectorResponse(StreamInput in) throws IOException { detector = new AnomalyDetector(in); returnJob = in.readBoolean(); if (returnJob) { - adJob = new AnomalyDetectorJob(in); + adJob = new Job(in); } else { adJob = null; } @@ -89,7 +89,7 @@ public GetAnomalyDetectorResponse( long primaryTerm, long seqNo, AnomalyDetector detector, - AnomalyDetectorJob adJob, + Job adJob, boolean returnJob, ADTask realtimeAdTask, ADTask historicalAdTask, @@ -197,7 +197,7 @@ public DetectorProfile getDetectorProfile() { return detectorProfile; } - public AnomalyDetectorJob getAdJob() { + public Job getAdJob() { return adJob; } diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java 
b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java index 473f247dd..c3bffb232 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportAction.java @@ -12,19 +12,11 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_GET_DETECTOR; -import static org.opensearch.ad.model.ADTaskType.ALL_DETECTOR_TASK_TYPES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; -import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; -import java.util.ArrayList; import java.util.Arrays; import java.util.EnumSet; -import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -34,25 +26,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; -import org.opensearch.action.get.MultiGetItemResponse; -import org.opensearch.action.get.MultiGetRequest; -import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.ActionFilters; -import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.AnomalyDetectorProfileRunner; import org.opensearch.ad.EntityProfileRunner; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.CheckedConsumer; @@ -63,35 +51,31 @@ import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.tasks.Task; import org.opensearch.timeseries.Name; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; -import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; import 
com.google.common.collect.Sets; -public class GetAnomalyDetectorTransportAction extends HandledTransportAction { +public class GetAnomalyDetectorTransportAction extends + BaseGetConfigTransportAction { - private static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); + public static final Logger LOG = LogManager.getLogger(GetAnomalyDetectorTransportAction.class); - private final ClusterService clusterService; - private final Client client; - private final SecurityClientUtil clientUtil; private final Set allProfileTypeStrs; private final Set allProfileTypes; private final Set defaultDetectorProfileTypes; private final Set allEntityProfileTypeStrs; private final Set allEntityProfileTypes; private final Set defaultEntityProfileTypes; - private final NamedXContentRegistry xContentRegistry; - private final DiscoveryNodeFilterer nodeFilter; - private final TransportService transportService; - private volatile Boolean filterByEnabled; - private final ADTaskManager adTaskManager; @Inject public GetAnomalyDetectorTransportAction( @@ -105,10 +89,28 @@ public GetAnomalyDetectorTransportAction( NamedXContentRegistry xContentRegistry, ADTaskManager adTaskManager ) { - super(GetAnomalyDetectorAction.NAME, transportService, actionFilters, GetAnomalyDetectorRequest::new); - this.clusterService = clusterService; - this.client = client; - this.clientUtil = clientUtil; + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + adTaskManager, + GetAnomalyDetectorAction.NAME, + AnomalyDetector.class, + AnomalyDetector.PARSE_FIELD_NAME, + ADTaskType.ALL_DETECTOR_TASK_TYPES, + ADTaskType.AD_REALTIME_HC_DETECTOR.name(), + ADTaskType.AD_REALTIME_SINGLE_STREAM.name(), + ADTaskType.AD_HISTORICAL_HC_DETECTOR.name(), + ADTaskType.AD_HISTORICAL_SINGLE_STREAM.name(), + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, + GetAnomalyDetectorResponse.class + ); + List allProfiles = Arrays.asList(DetectorProfileName.values()); this.allProfileTypes = EnumSet.copyOf(allProfiles); this.allProfileTypeStrs = getProfileListStrs(allProfiles); @@ -120,19 +122,12 @@ public GetAnomalyDetectorTransportAction( this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); this.defaultEntityProfileTypes = new HashSet(defaultEntityProfiles); - - this.xContentRegistry = xContentRegistry; - this.nodeFilter = nodeFilter; - filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); - this.transportService = transportService; - this.adTaskManager = adTaskManager; } @Override - protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionListener actionListener) { - String detectorID = request.getDetectorID(); - User user = getUserContext(client); + protected void doExecute(Task task, GetConfigRequest request, ActionListener actionListener) { + String detectorID = request.getConfigID(); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_DETECTOR); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { resolveUserAndExecute( @@ -143,7 +138,9 @@ protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionLis (anomalyDetector) -> getExecute(request, listener), client, 
clusterService, - xContentRegistry + xContentRegistry, + GetAnomalyDetectorResponse.class, + AnomalyDetector.class ); } catch (Exception e) { LOG.error(e); @@ -151,228 +148,6 @@ protected void doExecute(Task task, GetAnomalyDetectorRequest request, ActionLis } } - protected void getExecute(GetAnomalyDetectorRequest request, ActionListener listener) { - String detectorID = request.getDetectorID(); - String typesStr = request.getTypeStr(); - String rawPath = request.getRawPath(); - Entity entity = request.getEntity(); - boolean all = request.isAll(); - boolean returnJob = request.isReturnJob(); - boolean returnTask = request.isReturnTask(); - - try { - if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { - if (entity != null) { - Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); - EntityProfileRunner profileRunner = new EntityProfileRunner( - client, - clientUtil, - xContentRegistry, - AnomalyDetectorSettings.NUM_MIN_SAMPLES - ); - profileRunner - .profile( - detectorID, - entity, - entityProfilesToCollect, - ActionListener - .wrap( - profile -> { - listener - .onResponse( - new GetAnomalyDetectorResponse( - 0, - null, - 0, - 0, - null, - null, - false, - null, - null, - false, - null, - null, - profile, - true - ) - ); - }, - e -> listener.onFailure(e) - ) - ); - } else { - Set profilesToCollect = getProfilesToCollect(typesStr, all); - AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( - client, - clientUtil, - xContentRegistry, - nodeFilter, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - transportService, - adTaskManager - ); - profileRunner.profile(detectorID, getProfileActionListener(listener), profilesToCollect); - } - } else { - if (returnTask) { - adTaskManager.getAndExecuteOnLatestADTasks(detectorID, null, null, ALL_DETECTOR_TASK_TYPES, (taskList) -> { - Optional realtimeAdTask = Optional.empty(); - Optional historicalAdTask = Optional.empty(); - - if (taskList != null && taskList.size() > 0) { - Map adTasks = new HashMap<>(); - List duplicateAdTasks = new ArrayList<>(); - for (ADTask task : taskList) { - if (adTasks.containsKey(task.getTaskType())) { - LOG - .info( - "Found duplicate latest task of detector {}, task id: {}, task type: {}", - detectorID, - task.getTaskType(), - task.getTaskId() - ); - duplicateAdTasks.add(task); - continue; - } - adTasks.put(task.getTaskType(), task); - } - if (duplicateAdTasks.size() > 0) { - adTaskManager.resetLatestFlagAsFalse(duplicateAdTasks); - } - - if (adTasks.containsKey(ADTaskType.REALTIME_HC_DETECTOR.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.REALTIME_SINGLE_ENTITY.name())) { - realtimeAdTask = Optional.ofNullable(adTasks.get(ADTaskType.REALTIME_SINGLE_ENTITY.name())); - } - if (adTasks.containsKey(ADTaskType.HISTORICAL_HC_DETECTOR.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_HC_DETECTOR.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL_SINGLE_ENTITY.name())); - } else if (adTasks.containsKey(ADTaskType.HISTORICAL.name())) { - historicalAdTask = Optional.ofNullable(adTasks.get(ADTaskType.HISTORICAL.name())); - } - } - getDetectorAndJob(detectorID, returnJob, returnTask, realtimeAdTask, historicalAdTask, listener); - }, transportService, true, 2, listener); - } else { - 
getDetectorAndJob(detectorID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); - } - } - } catch (Exception e) { - LOG.error(e); - listener.onFailure(e); - } - } - - private void getDetectorAndJob( - String detectorID, - boolean returnJob, - boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - ActionListener listener - ) { - MultiGetRequest.Item adItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, detectorID); - MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); - if (returnJob) { - MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, detectorID); - multiGetRequest.add(adJobItem); - } - client.multiGet(multiGetRequest, onMultiGetResponse(listener, returnJob, returnTask, realtimeAdTask, historicalAdTask, detectorID)); - } - - private ActionListener onMultiGetResponse( - ActionListener listener, - boolean returnJob, - boolean returnTask, - Optional realtimeAdTask, - Optional historicalAdTask, - String detectorId - ) { - return new ActionListener() { - @Override - public void onResponse(MultiGetResponse multiGetResponse) { - MultiGetItemResponse[] responses = multiGetResponse.getResponses(); - AnomalyDetector detector = null; - AnomalyDetectorJob adJob = null; - String id = null; - long version = 0; - long seqNo = 0; - long primaryTerm = 0; - - for (MultiGetItemResponse response : responses) { - if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { - if (response.getResponse() == null || !response.getResponse().isExists()) { - listener.onFailure(new OpenSearchStatusException(FAIL_TO_FIND_CONFIG_MSG + detectorId, RestStatus.NOT_FOUND)); - return; - } - id = response.getId(); - version = response.getResponse().getVersion(); - primaryTerm = response.getResponse().getPrimaryTerm(); - seqNo = response.getResponse().getSeqNo(); - if (!response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - detector = parser.namedObject(AnomalyDetector.class, AnomalyDetector.PARSE_FIELD_NAME, null); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - - if (CommonName.JOB_INDEX.equals(response.getIndex())) { - if (response.getResponse() != null - && response.getResponse().isExists() - && !response.getResponse().isSourceEmpty()) { - try ( - XContentParser parser = RestHandlerUtils - .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - adJob = AnomalyDetectorJob.parse(parser); - } catch (Exception e) { - String message = "Failed to parse detector job " + detectorId; - listener.onFailure(buildInternalServerErrorResponse(e, message)); - return; - } - } - } - } - listener - .onResponse( - new GetAnomalyDetectorResponse( - version, - id, - primaryTerm, - seqNo, - detector, - adJob, - returnJob, - realtimeAdTask.orElse(null), - historicalAdTask.orElse(null), - returnTask, - RestStatus.OK, - null, - null, - false - ) - ); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }; - } - private ActionListener getProfileActionListener(ActionListener listener) { return ActionListener.wrap(new 
CheckedConsumer() { @Override @@ -385,11 +160,6 @@ public void accept(DetectorProfile profile) throws Exception { }, exception -> { listener.onFailure(exception); }); } - private OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { - LOG.error(errorMsg, e); - return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); - } - /** * * @param typesStr a list of input profile types separated by comma @@ -429,4 +199,106 @@ private Set getEntityProfilesToCollect(String typesStr, boole private Set getProfileListStrs(List profileList) { return profileList.stream().map(profile -> profile.getName()).collect(Collectors.toSet()); } + + @Override + protected void fillInHistoricalTaskforBwc(Map tasks, Optional historicalAdTask) { + if (tasks.containsKey(ADTaskType.HISTORICAL.name())) { + historicalAdTask = Optional.ofNullable(tasks.get(ADTaskType.HISTORICAL.name())); + } + } + + @Override + protected void getExecuteProfile( + GetConfigRequest request, + Entity entity, + String typesStr, + boolean all, + String configId, + ActionListener listener + ) { + if (entity != null) { + Set entityProfilesToCollect = getEntityProfilesToCollect(typesStr, all); + EntityProfileRunner profileRunner = new EntityProfileRunner( + client, + clientUtil, + xContentRegistry, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + profileRunner + .profile( + configId, + entity, + entityProfilesToCollect, + ActionListener + .wrap( + profile -> { + listener + .onResponse( + new GetAnomalyDetectorResponse( + 0, + null, + 0, + 0, + null, + null, + false, + null, + null, + false, + null, + null, + profile, + true + ) + ); + }, + e -> listener.onFailure(e) + ) + ); + } else { + Set profilesToCollect = getProfilesToCollect(typesStr, all); + AnomalyDetectorProfileRunner profileRunner = new AnomalyDetectorProfileRunner( + client, + clientUtil, + xContentRegistry, + nodeFilter, + TimeSeriesSettings.NUM_MIN_SAMPLES, + transportService, + taskManager + ); + profileRunner.profile(configId, getProfileActionListener(listener), profilesToCollect); + } + } + + @Override + protected GetAnomalyDetectorResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + AnomalyDetector config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus + ) { + return new GetAnomalyDetectorResponse( + version, + id, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + RestStatus.OK, + null, + null, + false + ); + } } diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java index 9ee038336..56103dfc9 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class IndexAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
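For context on what moved into BaseGetConfigTransportAction: the getDetectorAndJob/onMultiGetResponse pair deleted above fetched the detector config and its job document in a single round trip. A condensed sketch of that pattern, using only calls that appear in the deleted code (the wrapper class and method name here are hypothetical):

    import org.opensearch.action.ActionListener;
    import org.opensearch.action.get.MultiGetItemResponse;
    import org.opensearch.action.get.MultiGetRequest;
    import org.opensearch.action.get.MultiGetResponse;
    import org.opensearch.client.Client;
    import org.opensearch.timeseries.constant.CommonName;

    class DetectorAndJobFetchSketch {
        void fetch(Client client, String detectorId, boolean returnJob, ActionListener<MultiGetResponse> listener) {
            MultiGetRequest request = new MultiGetRequest()
                .add(new MultiGetRequest.Item(CommonName.CONFIG_INDEX, detectorId));
            if (returnJob) {
                // the job document shares the detector id but lives in the job index
                request.add(new MultiGetRequest.Item(CommonName.JOB_INDEX, detectorId));
            }
            client.multiGet(request, ActionListener.wrap(response -> {
                for (MultiGetItemResponse item : response.getResponses()) {
                    if (item.getResponse() == null || !item.getResponse().isExists()) {
                        continue; // e.g., the job doc is absent until the detector is started
                    }
                    // the deleted code branched on item.getIndex() to parse config vs. job
                }
                listener.onResponse(response);
            }, listener::onFailure));
        }
    }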
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/write"; public static final IndexAnomalyDetectorAction INSTANCE = new IndexAnomalyDetectorAction(); private IndexAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java index 572e847f9..6a4bb6d1d 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorRequest.java @@ -34,6 +34,9 @@ public class IndexAnomalyDetectorRequest extends ActionRequest { private Integer maxSingleEntityAnomalyDetectors; private Integer maxMultiEntityAnomalyDetectors; private Integer maxAnomalyFeatures; + // added during refactoring for forecasting. It is fine to add a new field + // since the request is handled by the same node. + private Integer maxCategoricalFields; public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -47,6 +50,7 @@ public IndexAnomalyDetectorRequest(StreamInput in) throws IOException { maxSingleEntityAnomalyDetectors = in.readInt(); maxMultiEntityAnomalyDetectors = in.readInt(); maxAnomalyFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); } public IndexAnomalyDetectorRequest( @@ -59,7 +63,8 @@ public IndexAnomalyDetectorRequest( TimeValue requestTimeout, Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, - Integer maxAnomalyFeatures + Integer maxAnomalyFeatures, + Integer maxCategoricalFields ) { super(); this.detectorID = detectorID; @@ -72,6 +77,7 @@ public IndexAnomalyDetectorRequest( this.maxSingleEntityAnomalyDetectors = maxSingleEntityAnomalyDetectors; this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; this.maxAnomalyFeatures = maxAnomalyFeatures; + this.maxCategoricalFields = maxCategoricalFields; } public String getDetectorID() { @@ -114,6 +120,10 @@ public Integer getMaxAnomalyFeatures() { return maxAnomalyFeatures; } + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -127,6 +137,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxSingleEntityAnomalyDetectors); out.writeInt(maxMultiEntityAnomalyDetectors); out.writeInt(maxAnomalyFeatures); + out.writeInt(maxCategoricalFields); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java index 06018ae6c..9dbb9492d 100644 --- a/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportAction.java @@ -13,10 +13,9 @@ import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_CREATE_DETECTOR; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_UPDATE_DETECTOR; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; -import static org.opensearch.timeseries.util.ParseUtils.getDetector; -import static
org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.timeseries.util.ParseUtils.getConfig; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; import java.util.List; @@ -29,13 +28,11 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -48,7 +45,10 @@ import org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; public class IndexAnomalyDetectorTransportAction extends HandledTransportAction { @@ -86,14 +86,14 @@ public IndexAnomalyDetectorTransportAction( this.xContentRegistry = xContentRegistry; this.adTaskManager = adTaskManager; this.searchFeatureDao = searchFeatureDao; - filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); this.settings = settings; } @Override protected void doExecute(Task task, IndexAnomalyDetectorRequest request, ActionListener actionListener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); String detectorId = request.getDetectorID(); RestRequest.Method method = request.getMethod(); String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_DETECTOR : FAIL_TO_CREATE_DETECTOR; @@ -126,7 +126,18 @@ private void resolveUserAndExecute( boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; // Update detector request, check if user has permissions to update the detector // Get detector and verify backend roles - getDetector(requestedUser, detectorId, listener, function, client, clusterService, xContentRegistry, filterByBackendRole); + getConfig( + requestedUser, + detectorId, + listener, + function, + client, + clusterService, + xContentRegistry, + filterByBackendRole, + IndexAnomalyDetectorResponse.class, + AnomalyDetector.class + ); } else { // Create Detector. No need to get current detector. 
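The new maxCategoricalFields field in IndexAnomalyDetectorRequest above deserves a note: as its comment says, the request never crosses node versions, so the only invariant is that writeTo and the StreamInput constructor append and read the field in the same position. A toy round-trip illustration with plain java.io streams (not the OpenSearch Stream* classes; field values are made up):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    // Sketch: appending one field is wire-safe when writer and reader share a version.
    class RequestWireSketch {
        int maxAnomalyFeatures = 5;
        int maxCategoricalFields = 2; // the newly appended field

        byte[] writeTo() throws IOException {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(maxAnomalyFeatures);
            out.writeInt(maxCategoricalFields); // appended last on the write side
            return bytes.toByteArray();
        }

        static RequestWireSketch readFrom(byte[] wire) throws IOException {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(wire));
            RequestWireSketch r = new RequestWireSketch();
            r.maxAnomalyFeatures = in.readInt();
            r.maxCategoricalFields = in.readInt(); // read last, mirroring writeTo
            return r;
        }
    }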
function.accept(null); @@ -154,6 +165,7 @@ protected void adExecute( Integer maxSingleEntityAnomalyDetectors = request.getMaxSingleEntityAnomalyDetectors(); Integer maxMultiEntityAnomalyDetectors = request.getMaxMultiEntityAnomalyDetectors(); Integer maxAnomalyFeatures = request.getMaxAnomalyFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); storedContext.restore(); checkIndicesAndExecute(detector.getIndices(), () -> { @@ -165,7 +177,6 @@ protected void adExecute( client, clientUtil, transportService, - listener, anomalyDetectionIndices, detectorId, seqNo, @@ -176,6 +187,7 @@ protected void adExecute( maxSingleEntityAnomalyDetectors, maxMultiEntityAnomalyDetectors, maxAnomalyFeatures, + maxCategoricalFields, method, xContentRegistry, detectorUser, @@ -183,7 +195,7 @@ protected void adExecute( searchFeatureDao, settings ); - indexAnomalyDetectorActionHandler.start(); + indexAnomalyDetectorActionHandler.start(listener); }, listener); } diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java index c90ecc446..5ae8d6c35 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class PreviewAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/preview"; public static final PreviewAnomalyDetectorAction INSTANCE = new PreviewAnomalyDetectorAction(); private PreviewAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java index 5d6bdd193..7d9f6a720 100644 --- a/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportAction.java @@ -12,11 +12,10 @@ package org.opensearch.ad.transport; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_PREVIEW_DETECTOR; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ANOMALY_FEATURES; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -35,7 +34,6 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.AnomalyDetectorRunner; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.model.AnomalyDetector; import 
org.opensearch.ad.model.AnomalyResult; @@ -51,11 +49,13 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.ClientException; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.ParseUtils; import org.opensearch.timeseries.util.RestHandlerUtils; import org.opensearch.transport.TransportService; @@ -68,7 +68,7 @@ public class PreviewAnomalyDetectorTransportAction extends private final NamedXContentRegistry xContentRegistry; private volatile Integer maxAnomalyFeatures; private volatile Boolean filterByEnabled; - private final ADCircuitBreakerService adCircuitBreakerService; + private final CircuitBreakerService adCircuitBreakerService; private Semaphore lock; @Inject @@ -80,7 +80,7 @@ public PreviewAnomalyDetectorTransportAction( Client client, AnomalyDetectorRunner anomalyDetectorRunner, NamedXContentRegistry xContentRegistry, - ADCircuitBreakerService adCircuitBreakerService + CircuitBreakerService adCircuitBreakerService ) { super(PreviewAnomalyDetectorAction.NAME, transportService, actionFilters, PreviewAnomalyDetectorRequest::new); this.clusterService = clusterService; @@ -89,8 +89,8 @@ public PreviewAnomalyDetectorTransportAction( this.xContentRegistry = xContentRegistry; maxAnomalyFeatures = MAX_ANOMALY_FEATURES.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ANOMALY_FEATURES, it -> maxAnomalyFeatures = it); - filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); this.adCircuitBreakerService = adCircuitBreakerService; this.lock = new Semaphore(MAX_CONCURRENT_PREVIEW.get(settings), true); clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CONCURRENT_PREVIEW, it -> { lock = new Semaphore(it); }); @@ -103,7 +103,7 @@ protected void doExecute( ActionListener actionListener ) { String detectorId = request.getId(); - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_PREVIEW_DETECTOR); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { resolveUserAndExecute( @@ -114,7 +114,9 @@ protected void doExecute( (anomalyDetector) -> previewExecute(request, context, listener), client, clusterService, - xContentRegistry + xContentRegistry, + PreviewAnomalyDetectorResponse.class, + AnomalyDetector.class ); } catch (Exception e) { logger.error(e); diff --git a/src/main/java/org/opensearch/ad/transport/ProfileAction.java b/src/main/java/org/opensearch/ad/transport/ProfileAction.java index 291dd0982..559b6a892 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileAction.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileAction.java @@ -12,14 +12,14 @@ package org.opensearch.ad.transport; import 
org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; /** * Profile transport action */ public class ProfileAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detectors/profile"; public static final ProfileAction INSTANCE = new ProfileAction(); /** diff --git a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java index 9517f6add..47fe5b901 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileNodeResponse.java @@ -17,13 +17,13 @@ import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.ModelProfile; /** * Profile response on a node diff --git a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java index 11ba28163..66e2cb90b 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileResponse.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileResponse.java @@ -21,13 +21,13 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.ClusterName; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentFragment; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.model.ModelProfile; /** * This class consists of the aggregated responses from the nodes diff --git a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java index e05251f2f..8cd5dd2cd 100644 --- a/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ProfileTransportAction.java @@ -11,7 +11,7 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; import java.io.IOException; import java.util.List; @@ -23,26 +23,29 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.caching.ADPriorityCache; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.DetectorProfileName; -import 
org.opensearch.ad.model.ModelProfile; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.model.ModelProfile; import org.opensearch.transport.TransportService; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + /** * This class contains the logic to extract the stats from the nodes */ public class ProfileTransportAction extends TransportNodesAction { private static final Logger LOG = LogManager.getLogger(ProfileTransportAction.class); - private ModelManager modelManager; + private ADModelManager modelManager; private FeatureManager featureManager; - private CacheProvider cacheProvider; + private CacheProvider cacheProvider; // the number of models to return. Defaults to 10. private volatile int numModelsToReturn; @@ -64,9 +67,9 @@ public ProfileTransportAction( ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, - ModelManager modelManager, + ADModelManager modelManager, FeatureManager featureManager, - CacheProvider cacheProvider, + CacheProvider cacheProvider, Settings settings ) { super( @@ -83,8 +86,8 @@ public ProfileTransportAction( this.modelManager = modelManager; this.featureManager = featureManager; this.cacheProvider = cacheProvider; - this.numModelsToReturn = MAX_MODEL_SIZE_PER_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); + this.numModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.numModelsToReturn = it); } @Override diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java index 147ff74cb..b38a088eb 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFPollingAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcfpolling"; public static final RCFPollingAction INSTANCE = new RCFPollingAction(); private RCFPollingAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java index 49e6f0153..35ed245b3 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFPollingTransportAction.java @@ -20,9 +20,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -31,7 +29,9 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; @@ -48,7 +48,7 @@ public class RCFPollingTransportAction extends HandledTransportAction rcfNode = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(rcfModelID); + Optional rcfNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); if (!rcfNode.isPresent()) { listener.onFailure(new TimeSeriesException(adID, NO_NODE_FOUND_MSG)); return; diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java index 3480e880a..f551f97df 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class RCFResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "rcf/result"; public static final RCFResultAction INSTANCE = new RCFResultAction(); private RCFResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java index f9d63365c..e0e54bae3 100644 --- a/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/RCFResultTransportAction.java @@ -21,34 +21,34 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.stats.ADStats; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.inject.Inject; import org.opensearch.tasks.Task; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; import org.opensearch.transport.TransportService; public class RCFResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(RCFResultTransportAction.class); - private ModelManager manager; - private ADCircuitBreakerService adCircuitBreakerService; + private ADModelManager manager; + private CircuitBreakerService adCircuitBreakerService; private HashRing hashRing; - private ADStats adStats; + private Stats adStats; @Inject public RCFResultTransportAction( ActionFilters actionFilters, TransportService transportService, - ModelManager manager, - ADCircuitBreakerService adCircuitBreakerService, + ADModelManager manager, + CircuitBreakerService adCircuitBreakerService, HashRing hashRing, - ADStats adStats + Stats adStats ) { super(RCFResultAction.NAME, transportService, actionFilters, RCFResultRequest::new); this.manager = manager; @@ -69,7 +69,7 @@ protected void doExecute(Task task, RCFResultRequest request, ActionListener { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "tasks/search"; public static final SearchADTasksAction INSTANCE = new SearchADTasksAction(); private SearchADTasksAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java index c15ece9ab..90ae6cede 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/search"; public static final SearchAnomalyDetectorAction INSTANCE = new SearchAnomalyDetectorAction(); private SearchAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java index 3f4f7c2fc..b784e4322 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyDetectorInfoAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/info"; public static final SearchAnomalyDetectorInfoAction INSTANCE = new SearchAnomalyDetectorInfoAction(); private SearchAnomalyDetectorInfoAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java index 7e0178393..e2a5969bd 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchAnomalyResultAction.java @@ -13,11 +13,11 @@ import org.opensearch.action.ActionType; import org.opensearch.action.search.SearchResponse; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/search"; public static final SearchAnomalyResultAction INSTANCE = new SearchAnomalyResultAction(); private SearchAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java index ee89c4179..8956eeb1d 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class SearchTopAnomalyResultAction extends ActionType { // External Action which used for public facing RestAPIs. 
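(Annotation: all of these Search*Action classes follow the same pattern — a singleton ActionType keyed by a prefixed NAME route. Call sites go through the node client; a hypothetical example, with the response handling purely illustrative:

// Execute the transport action registered under SearchAnomalyDetectorAction.NAME.
client.execute(SearchAnomalyDetectorAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
    // consume hits, aggregations, etc.
}, exception -> LOG.error("detector search failed", exception)));

The diff continues:)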
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "result/topAnomalies"; public static final SearchTopAnomalyResultAction INSTANCE = new SearchTopAnomalyResultAction(); private SearchTopAnomalyResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java index 86ad7941a..2268b1dc1 100644 --- a/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/SearchTopAnomalyResultTransportAction.java @@ -64,6 +64,7 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; @@ -219,7 +220,7 @@ public SearchTopAnomalyResultTransportAction( @Override protected void doExecute(Task task, SearchTopAnomalyResultRequest request, ActionListener listener) { - GetAnomalyDetectorRequest getAdRequest = new GetAnomalyDetectorRequest( + GetConfigRequest getAdRequest = new GetConfigRequest( request.getId(), // The default version value used in org.opensearch.rest.action.RestActions.parseVersion() -3L, diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java index 3c1f53d9d..bd08656f1 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class StatsAnomalyDetectorAction extends ActionType { // External Action which used for public facing RestAPIs. 
- public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/stats"; public static final StatsAnomalyDetectorAction INSTANCE = new StatsAnomalyDetectorAction(); private StatsAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java index 4a233fc62..6dbc91068 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorResponse.java @@ -14,7 +14,6 @@ import java.io.IOException; import org.opensearch.action.ActionResponse; -import org.opensearch.ad.stats.ADStatsResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; diff --git a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java index caf4bd42a..e706b8382 100644 --- a/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportAction.java @@ -28,9 +28,6 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorType; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.ADStatsResponse; -import org.opensearch.ad.util.MultiResponsesDelegateActionListener; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -43,6 +40,8 @@ import org.opensearch.tasks.Task; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; import org.opensearch.transport.TransportService; public class StatsAnomalyDetectorTransportAction extends HandledTransportAction { @@ -50,7 +49,7 @@ public class StatsAnomalyDetectorTransportAction extends HandledTransportAction< private final Logger logger = LogManager.getLogger(StatsAnomalyDetectorTransportAction.class); private final Client client; - private final ADStats adStats; + private final Stats adStats; private final ClusterService clusterService; @Inject @@ -58,7 +57,7 @@ public StatsAnomalyDetectorTransportAction( TransportService transportService, ActionFilters actionFilters, Client client, - ADStats adStats, + Stats adStats, ClusterService clusterService ) { @@ -128,8 +127,8 @@ private void getClusterStats( ) { ADStatsResponse adStatsResponse = new ADStatsResponse(); if ((adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()) - || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()) + || adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) && clusterService.state().getRoutingTable().hasIndex(CommonName.CONFIG_INDEX)) { TermsAggregationBuilder termsAgg = 
AggregationBuilders.terms(DETECTOR_TYPE_AGG).field(AnomalyDetector.DETECTOR_TYPE_FIELD); @@ -158,11 +157,11 @@ private void getClusterStats( if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.DETECTOR_COUNT.getName())) { adStats.getStat(StatNames.DETECTOR_COUNT.getName()).setValue(totalDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName())) { + adStats.getStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()).setValue(totalSingleEntityDetectors); } - if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName())) { - adStats.getStat(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); + if (adStatsRequest.getStatsToBeRetrieved().contains(StatNames.HC_DETECTOR_COUNT.getName())) { + adStats.getStat(StatNames.HC_DETECTOR_COUNT.getName()).setValue(totalMultiEntityDetectors); } adStatsResponse.setClusterStats(getClusterStatsMap(adStatsRequest)); listener.onResponse(adStatsResponse); diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java index 5c7182920..15f617e78 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorAction.java @@ -12,15 +12,16 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; -public class StopDetectorAction extends ActionType { +public class StopDetectorAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. 
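(Annotation: returning to the stats action above — the detector counts come from a single terms aggregation over the config index, bucketing on the detector type field. A self-contained sketch of building that request; the aggregation name is illustrative, while the field and index constants are the ones used in the diff:

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.timeseries.constant.CommonName;

final class DetectorTypeCounts {
    // One bucket per detector type (single-stream vs. high-cardinality);
    // size(0) returns aggregation buckets only, no document hits.
    static SearchRequest buildRequest() {
        SearchSourceBuilder source = new SearchSourceBuilder()
            .aggregation(AggregationBuilders.terms("detector_type_agg").field(AnomalyDetector.DETECTOR_TYPE_FIELD))
            .size(0);
        return new SearchRequest().indices(CommonName.CONFIG_INDEX).source(source);
    }
}

Each bucket's key is then a detector type and its doc count feeds the corresponding StatNames entry, as in getClusterStats above. The diff continues:)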
- public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "detector/stop"; public static final StopDetectorAction INSTANCE = new StopDetectorAction(); private StopDetectorAction() { - super(NAME, StopDetectorResponse::new); + super(NAME, StopConfigResponse::new); } } diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java index f84d4114e..03d02763c 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/StopDetectorTransportAction.java @@ -27,10 +27,13 @@ import org.opensearch.common.inject.Inject; import org.opensearch.tasks.Task; import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; -public class StopDetectorTransportAction extends HandledTransportAction { +public class StopDetectorTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(StopDetectorTransportAction.class); @@ -44,19 +47,19 @@ public StopDetectorTransportAction( ActionFilters actionFilters, Client client ) { - super(StopDetectorAction.NAME, transportService, actionFilters, StopDetectorRequest::new); + super(StopDetectorAction.NAME, transportService, actionFilters, StopConfigRequest::new); this.client = client; this.nodeFilter = nodeFilter; } @Override - protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { - StopDetectorRequest request = StopDetectorRequest.fromActionRequest(actionRequest); - String adID = request.getAdID(); + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest); + String adID = request.getConfigID(); try { DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(adID, dataNodes); - client.execute(DeleteModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + client.execute(DeleteADModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { if (response.hasFailures()) { LOG.warn("Cannot delete all models of detector {}", adID); for (FailedNodeException failedNodeException : response.failures()) { @@ -64,14 +67,14 @@ protected void doExecute(Task task, ActionRequest actionRequest, ActionListener< } // if customers are using an updated detector and we haven't deleted old // checkpoints, customer would have trouble - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); } else { LOG.info("models of detector {} get deleted", adID); - listener.onResponse(new StopDetectorResponse(true)); + listener.onResponse(new StopConfigResponse(true)); } }, exception -> { LOG.error(new ParameterizedMessage("Deletion of detector [{}] has exception.", adID), exception); - listener.onResponse(new StopDetectorResponse(false)); + listener.onResponse(new StopConfigResponse(false)); })); } catch (Exception e) { LOG.error(FAIL_TO_STOP_DETECTOR + " " + adID, e); diff --git 
a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java index 1561c08dc..f8a81252a 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ThresholdResultAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. - public static final String NAME = CommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; + public static final String NAME = ADCommonValue.INTERNAL_ACTION_PREFIX + "threshold/result"; public static final ThresholdResultAction INSTANCE = new ThresholdResultAction(); private ThresholdResultAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java index 9e292b676..a79ef0814 100644 --- a/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ThresholdResultTransportAction.java @@ -16,7 +16,7 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.common.inject.Inject; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -24,10 +24,10 @@ public class ThresholdResultTransportAction extends HandledTransportAction { private static final Logger LOG = LogManager.getLogger(ThresholdResultTransportAction.class); - private ModelManager manager; + private ADModelManager manager; @Inject - public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ModelManager manager) { + public ThresholdResultTransportAction(ActionFilters actionFilters, TransportService transportService, ADModelManager manager) { super(ThresholdResultAction.NAME, transportService, actionFilters, ThresholdResultRequest::new); this.manager = manager; } diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java index 432166ac2..9af0f0403 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorAction.java @@ -12,11 +12,11 @@ package org.opensearch.ad.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; +import org.opensearch.ad.constant.ADCommonValue; public class ValidateAnomalyDetectorAction extends ActionType { - public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; + public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/validate"; public static final ValidateAnomalyDetectorAction INSTANCE = new ValidateAnomalyDetectorAction(); private ValidateAnomalyDetectorAction() { diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java index 3ee1f0a6e..203e75f7c 100644 --- 
a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorRequest.java @@ -28,6 +28,9 @@ public class ValidateAnomalyDetectorRequest extends ActionRequest { private final Integer maxMultiEntityAnomalyDetectors; private final Integer maxAnomalyFeatures; private final TimeValue requestTimeout; + // added during refactoring for forecasting. It is fine to add a new field + // since the request is handled by the same node. + private Integer maxCategoricalFields; public ValidateAnomalyDetectorRequest(StreamInput in) throws IOException { super(in); @@ -37,6 +40,7 @@ public ValidateAnomalyDetectorRequest(StreamInput in) throws IOException { maxMultiEntityAnomalyDetectors = in.readInt(); maxAnomalyFeatures = in.readInt(); requestTimeout = in.readTimeValue(); + maxCategoricalFields = in.readInt(); } public ValidateAnomalyDetectorRequest( @@ -45,7 +49,8 @@ public ValidateAnomalyDetectorRequest( Integer maxSingleEntityAnomalyDetectors, Integer maxMultiEntityAnomalyDetectors, Integer maxAnomalyFeatures, - TimeValue requestTimeout + TimeValue requestTimeout, + Integer maxCategoricalFields ) { this.detector = detector; this.validationType = validationType; @@ -53,6 +58,7 @@ public ValidateAnomalyDetectorRequest( this.maxMultiEntityAnomalyDetectors = maxMultiEntityAnomalyDetectors; this.maxAnomalyFeatures = maxAnomalyFeatures; this.requestTimeout = requestTimeout; + this.maxCategoricalFields = maxCategoricalFields; } @Override @@ -64,6 +70,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxMultiEntityAnomalyDetectors); out.writeInt(maxAnomalyFeatures); out.writeTimeValue(requestTimeout); + out.writeInt(maxCategoricalFields); } @Override @@ -94,4 +101,8 @@ public Integer getMaxAnomalyFeatures() { public TimeValue getRequestTimeout() { return requestTimeout; } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } } diff --git a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java index ecd0ca07c..4758e8220 100644 --- a/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java +++ b/src/main/java/org/opensearch/ad/transport/ValidateAnomalyDetectorTransportAction.java @@ -11,9 +11,8 @@ package org.opensearch.ad.transport; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.ParseUtils.checkFilterByBackendRoles; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; import java.time.Clock; import java.util.HashMap; @@ -28,13 +27,11 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.DetectorValidationIssue; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; @@ -48,10
+45,13 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.model.ValidationAspect; import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; public class ValidateAnomalyDetectorTransportAction extends @@ -86,8 +86,8 @@ public ValidateAnomalyDetectorTransportAction( this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; this.anomalyDetectionIndices = anomalyDetectionIndices; - this.filterByEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.filterByEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); this.searchFeatureDao = searchFeatureDao; this.clock = Clock.systemUTC(); this.settings = settings; @@ -95,7 +95,7 @@ public ValidateAnomalyDetectorTransportAction( @Override protected void doExecute(Task task, ValidateAnomalyDetectorRequest request, ActionListener listener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); AnomalyDetector anomalyDetector = request.getDetector(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { resolveUserAndExecute(user, listener, () -> validateExecute(request, user, context, listener)); @@ -150,13 +150,13 @@ private void validateExecute( clusterService, client, clientUtil, - validateListener, anomalyDetectionIndices, detector, request.getRequestTimeout(), request.getMaxSingleEntityAnomalyDetectors(), request.getMaxMultiEntityAnomalyDetectors(), request.getMaxAnomalyFeatures(), + request.getMaxCategoricalFields(), RestRequest.Method.POST, xContentRegistry, user, @@ -166,7 +166,7 @@ private void validateExecute( settings ); try { - handler.start(); + handler.start(validateListener); } catch (Exception exception) { String errorMessage = String .format(Locale.ROOT, "Unknown exception caught while validating detector %s", request.getDetector()); diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..76c3069a2 --- /dev/null +++ b/src/main/java/org/opensearch/ad/transport/handler/ADIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.ad.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.transport.ADResultBulkAction; +import org.opensearch.ad.transport.ADResultBulkRequest; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ADIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ADIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ADIndexMemoryPressureAwareResultHandler(Client client, ADIndexManagement anomalyDetectionIndices) { + super(client, anomalyDetectionIndices); + } + + @Override + public void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error in bulking results", exception); + listener.onFailure(exception); + })); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java index 4831eae88..c1fccd50c 100644 --- a/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java +++ b/src/main/java/org/opensearch/ad/transport/handler/ADSearchHandler.java @@ -12,9 +12,7 @@ package org.opensearch.ad.transport.handler; import static org.opensearch.ad.constant.ADCommonMessages.FAIL_TO_SEARCH; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES; -import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter; -import static org.opensearch.timeseries.util.ParseUtils.getUserContext; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES; import static org.opensearch.timeseries.util.ParseUtils.isAdmin; import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; @@ -29,6 +27,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; +import org.opensearch.timeseries.util.ParseUtils; /** * Handle general search request, check user role and return search response. 
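(Annotation: a caller of the new ADIndexMemoryPressureAwareResultHandler above only supplies the bulk request and a listener; the base class decides, based on index-memory pressure, how much to persist. A hypothetical call site:

// Save a batch of anomaly results; failures are logged, not rethrown.
resultHandler.bulk(adResultBulkRequest, ActionListener.wrap(
    bulkResponse -> LOG.debug("anomaly results saved"),
    exception -> LOG.warn("failed to save anomaly results", exception)
));

The diff continues:)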
@@ -40,8 +39,8 @@ public class ADSearchHandler { public ADSearchHandler(Settings settings, ClusterService clusterService, Client client) { this.client = client; - filterEnabled = AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); + filterEnabled = AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterEnabled = it); } /** @@ -52,7 +51,7 @@ public ADSearchHandler(Settings settings, ClusterService clusterService, Client * @param actionListener action listerner */ public void search(SearchRequest request, ActionListener actionListener) { - User user = getUserContext(client); + User user = ParseUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_SEARCH); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { validateRole(request, user, listener); @@ -72,7 +71,7 @@ private void validateRole(SearchRequest request, User user, ActionListener { - private static final Logger LOG = LogManager.getLogger(MultiEntityResultHandler.class); - // package private for testing - static final String SUCCESS_SAVING_RESULT_MSG = "Result saved successfully."; - static final String CANNOT_SAVE_RESULT_ERR_MSG = "Cannot save results due to write block."; - - @Inject - public MultiEntityResultHandler( - Client client, - Settings settings, - ThreadPool threadPool, - ADIndexManagement anomalyDetectionIndices, - ClientUtil clientUtil, - IndexUtils indexUtils, - ClusterService clusterService - ) { - super( - client, - settings, - threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - clientUtil, - indexUtils, - clusterService - ); - } - - /** - * Execute the bulk request - * @param currentBulkRequest The bulk request - * @param listener callback after flushing - */ - public void flush(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) { - listener.onFailure(new TimeSeriesException(CANNOT_SAVE_RESULT_ERR_MSG)); - return; - } - - try { - if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) { - anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> { - if (initResponse.isAcknowledged()) { - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Creating result index with mappings call not acknowledged."); - listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged.")); - } - }, exception -> { - if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { - // It is possible the index has been created while we sending the create request - bulk(currentBulkRequest, listener); - } else { - LOG.warn("Unexpected error creating result index", exception); - listener.onFailure(exception); - } - })); - } else { - bulk(currentBulkRequest, listener); - } - } catch (Exception e) { - LOG.warn("Error in bulking results", e); - listener.onFailure(e); - } - } - - private void bulk(ADResultBulkRequest currentBulkRequest, ActionListener listener) { - if (currentBulkRequest.numberOfActions() <= 0) { - listener.onFailure(new TimeSeriesException("no result to save")); - return; - } - client.execute(ADResultBulkAction.INSTANCE, currentBulkRequest, 
ActionListener.wrap(response -> { - LOG.debug(SUCCESS_SAVING_RESULT_MSG); - listener.onResponse(response); - }, exception -> { - LOG.error("Error in bulking results", exception); - listener.onFailure(exception); - })); - } -} diff --git a/src/main/java/org/opensearch/ad/util/ClientUtil.java b/src/main/java/org/opensearch/ad/util/ClientUtil.java deleted file mode 100644 index d85d4fdf7..000000000 --- a/src/main/java/org/opensearch/ad/util/ClientUtil.java +++ /dev/null @@ -1,332 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.util; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; - -import java.util.List; -import java.util.Optional; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiConsumer; -import java.util.function.Function; - -import org.apache.logging.log4j.Logger; -import org.opensearch.OpenSearchException; -import org.opensearch.OpenSearchTimeoutException; -import org.opensearch.action.ActionFuture; -import org.opensearch.action.ActionListener; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionResponse; -import org.opensearch.action.ActionType; -import org.opensearch.action.LatchedActionListener; -import org.opensearch.action.TaskOperationFailure; -import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction; -import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; -import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; -import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksAction; -import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksRequest; -import org.opensearch.action.admin.cluster.node.tasks.list.ListTasksResponse; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.client.Client; -import org.opensearch.common.inject.Inject; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.tasks.Task; -import org.opensearch.tasks.TaskId; -import org.opensearch.tasks.TaskInfo; -import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.common.exception.InternalFailure; -import org.opensearch.timeseries.constant.CommonMessages; - -public class ClientUtil { - private volatile TimeValue requestTimeout; - private Client client; - private final Throttler throttler; - private ThreadPool threadPool; - - @Inject - public ClientUtil(Settings setting, Client client, Throttler throttler, ThreadPool threadPool) { - this.requestTimeout = REQUEST_TIMEOUT.get(setting); - this.client = client; - this.throttler = throttler; - this.threadPool = threadPool; - } - - /** - * Send a nonblocking request with a timeout and return response. Blocking is not allowed in a - * transport call context. 
See BaseFuture.blockingAllowed - * @param request request like index/search/get - * @param LOG log - * @param consumer functional interface to operate as a client request like client::get - * @param ActionRequest - * @param ActionResponse - * @return the response - * @throws OpenSearchTimeoutException when we cannot get response within time. - * @throws IllegalStateException when the waiting thread is interrupted - */ - public Optional timedRequest( - Request request, - Logger LOG, - BiConsumer> consumer - ) { - try { - AtomicReference respReference = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(1); - - consumer - .accept( - request, - new LatchedActionListener( - ActionListener - .wrap( - response -> { respReference.set(response); }, - exception -> { LOG.error("Cannot get response for request {}, error: {}", request, exception); } - ), - latch - ) - ); - - if (!latch.await(requestTimeout.getSeconds(), TimeUnit.SECONDS)) { - throw new OpenSearchTimeoutException("Cannot get response within time limit: " + request.toString()); - } - return Optional.ofNullable(respReference.get()); - } catch (InterruptedException e1) { - LOG.error(CommonMessages.WAIT_ERR_MSG); - throw new IllegalStateException(e1); - } - } - - /** - * Send an asynchronous request and handle response with the provided listener. - * @param ActionRequest - * @param ActionResponse - * @param request request body - * @param consumer request method, functional interface to operate as a client request like client::get - * @param listener needed to handle response - */ - public void asyncRequest( - Request request, - BiConsumer> consumer, - ActionListener listener - ) { - consumer - .accept( - request, - ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); }) - ); - } - - /** - * Execute a transport action and handle response with the provided listener. - * @param ActionRequest - * @param ActionResponse - * @param action transport action - * @param request request body - * @param listener needed to handle response - */ - public void execute( - ActionType action, - Request request, - ActionListener listener - ) { - client - .execute( - action, - request, - ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); }) - ); - } - - /** - * Send an synchronous request and handle response with the provided listener. - * - * @deprecated use asyncRequest with listener instead. - * - * @param ActionRequest - * @param ActionResponse - * @param request request body - * @param function request method, functional interface to operate as a client request like client::get - * @return the response - */ - @Deprecated - public Response syncRequest( - Request request, - Function> function - ) { - return function.apply(request).actionGet(requestTimeout); - } - - /** - * Send a nonblocking request with a timeout and return response. - * If there is already a query running on given detector, it will try to - * cancel the query. Otherwise it will add this query to the negative cache - * and then attach the AnomalyDetection specific header to the request. - * Once the request complete, it will be removed from the negative cache. 
- * @param ActionRequest - * @param ActionResponse - * @param request request like index/search/get - * @param LOG log - * @param consumer functional interface to operate as a client request like client::get - * @param detector Anomaly Detector - * @return the response - * @throws InternalFailure when there is already a query running - * @throws OpenSearchTimeoutException when we cannot get response within time. - * @throws IllegalStateException when the waiting thread is interrupted - */ - public Optional throttledTimedRequest( - Request request, - Logger LOG, - BiConsumer> consumer, - AnomalyDetector detector - ) { - - try { - String detectorId = detector.getId(); - if (!throttler.insertFilteredQuery(detectorId, request)) { - LOG.info("There is one query running for detectorId: {}. Trying to cancel the long running query", detectorId); - cancelRunningQuery(client, detectorId, LOG); - throw new InternalFailure(detector.getId(), "There is already a query running on AnomalyDetector"); - } - AtomicReference respReference = new AtomicReference<>(); - final CountDownLatch latch = new CountDownLatch(1); - - try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { - assert context != null; - threadPool.getThreadContext().putHeader(Task.X_OPAQUE_ID, ADCommonName.ANOMALY_DETECTOR + ":" + detectorId); - consumer.accept(request, new LatchedActionListener(ActionListener.wrap(response -> { - // clear negative cache - throttler.clearFilteredQuery(detectorId); - respReference.set(response); - }, exception -> { - // clear negative cache - throttler.clearFilteredQuery(detectorId); - LOG.error("Cannot get response for request {}, error: {}", request, exception); - }), latch)); - } catch (Exception e) { - LOG.error("Failed to process the request for detectorId: {}.", detectorId); - throttler.clearFilteredQuery(detectorId); - throw e; - } - - if (!latch.await(requestTimeout.getSeconds(), TimeUnit.SECONDS)) { - throw new OpenSearchTimeoutException("Cannot get response within time limit: " + request.toString()); - } - return Optional.ofNullable(respReference.get()); - } catch (InterruptedException e1) { - LOG.error(CommonMessages.WAIT_ERR_MSG); - throw new IllegalStateException(e1); - } - } - - /** - * Check if there is running query on given detector - * @param detector Anomaly Detector - * @return true if given detector has a running query else false - */ - public boolean hasRunningQuery(AnomalyDetector detector) { - return throttler.getFilteredQuery(detector.getId()).isPresent(); - } - - /** - * Cancel long running query for given detectorId - * @param client OpenSearch client - * @param detectorId Anomaly Detector Id - * @param LOG Logger - */ - private void cancelRunningQuery(Client client, String detectorId, Logger LOG) { - ListTasksRequest listTasksRequest = new ListTasksRequest(); - listTasksRequest.setActions("*search*"); - client - .execute( - ListTasksAction.INSTANCE, - listTasksRequest, - ActionListener.wrap(response -> { onListTaskResponse(response, detectorId, LOG); }, exception -> { - LOG.error("List Tasks failed.", exception); - throw new InternalFailure(detectorId, "Failed to list current tasks", exception); - }) - ); - } - - /** - * Helper function to handle ListTasksResponse - * @param listTasksResponse ListTasksResponse - * @param detectorId Anomaly Detector Id - * @param LOG Logger - */ - private void onListTaskResponse(ListTasksResponse listTasksResponse, String detectorId, Logger LOG) { - List tasks = listTasksResponse.getTasks(); - TaskId 
matchedParentTaskId = null; - TaskId matchedSingleTaskId = null; - for (TaskInfo task : tasks) { - if (!task.getHeaders().isEmpty() - && task.getHeaders().get(Task.X_OPAQUE_ID).equals(ADCommonName.ANOMALY_DETECTOR + ":" + detectorId)) { - if (!task.getParentTaskId().equals(TaskId.EMPTY_TASK_ID)) { - // we found the parent task, don't need to check more - matchedParentTaskId = task.getParentTaskId(); - break; - } else { - // we found one task, keep checking other tasks - matchedSingleTaskId = task.getTaskId(); - } - } - } - // case 1: given detectorId is not in current task list - if (matchedParentTaskId == null && matchedSingleTaskId == null) { - // log and then clear negative cache - LOG.info("Couldn't find task for detectorId: {}. Clean this entry from Throttler", detectorId); - throttler.clearFilteredQuery(detectorId); - return; - } - // case 2: we can find the task for given detectorId - CancelTasksRequest cancelTaskRequest = new CancelTasksRequest(); - if (matchedParentTaskId != null) { - cancelTaskRequest.setParentTaskId(matchedParentTaskId); - LOG.info("Start to cancel task for parentTaskId: {}", matchedParentTaskId.toString()); - } else { - cancelTaskRequest.setTaskId(matchedSingleTaskId); - LOG.info("Start to cancel task for taskId: {}", matchedSingleTaskId.toString()); - } - - client - .execute( - CancelTasksAction.INSTANCE, - cancelTaskRequest, - ActionListener.wrap(response -> { onCancelTaskResponse(response, detectorId, LOG); }, exception -> { - LOG.error("Failed to cancel task for detectorId: " + detectorId, exception); - throw new InternalFailure(detectorId, "Failed to cancel current tasks", exception); - }) - ); - } - - /** - * Helper function to handle CancelTasksResponse - * @param cancelTasksResponse CancelTasksResponse - * @param detectorId Anomaly Detector Id - * @param LOG Logger - */ - private void onCancelTaskResponse(CancelTasksResponse cancelTasksResponse, String detectorId, Logger LOG) { - // todo: adding retry mechanism - List nodeFailures = cancelTasksResponse.getNodeFailures(); - List taskFailures = cancelTasksResponse.getTaskFailures(); - if (nodeFailures.isEmpty() && taskFailures.isEmpty()) { - LOG.info("Cancelling query for detectorId: {} succeeds. Clear entry from Throttler", detectorId); - throttler.clearFilteredQuery(detectorId); - return; - } - LOG.error("Failed to cancel task for detectorId: " + detectorId); - throw new InternalFailure(detectorId, "Failed to cancel current tasks due to node or task failures"); - } -} diff --git a/src/main/java/org/opensearch/ad/util/Throttler.java b/src/main/java/org/opensearch/ad/util/Throttler.java deleted file mode 100644 index 177b612a2..000000000 --- a/src/main/java/org/opensearch/ad/util/Throttler.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad.util; - -import java.time.Clock; -import java.time.Instant; -import java.util.AbstractMap; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; - -import org.opensearch.action.ActionRequest; - -/** - * Utility functions for throttling query. - */ -public class Throttler { - // negativeCache is used to reject search query if given detector already has one query running - // key is detectorId, value is an entry. 
Key is ActionRequest and value is the timestamp - private final ConcurrentHashMap> negativeCache; - private final Clock clock; - - public Throttler(Clock clock) { - this.negativeCache = new ConcurrentHashMap<>(); - this.clock = clock; - } - - /** - * This will be used when dependency injection directly/indirectly injects a Throttler object. Without this object, - * node start might fail due to not being able to find a Clock object. We removed Clock object association in - * https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/305 - */ - public Throttler() { - this(Clock.systemUTC()); - } - - /** - * Get negative cache value(ActionRequest, Instant) for given detector - * @param detectorId AnomalyDetector ID - * @return negative cache value(ActionRequest, Instant) - */ - public Optional> getFilteredQuery(String detectorId) { - return Optional.ofNullable(negativeCache.get(detectorId)); - } - - /** - * Insert the negative cache entry for given detector - * If key already exists, return false. Otherwise true. - * @param detectorId AnomalyDetector ID - * @param request ActionRequest - * @return true if key doesn't exist otherwise false. - */ - public synchronized boolean insertFilteredQuery(String detectorId, ActionRequest request) { - return negativeCache.putIfAbsent(detectorId, new AbstractMap.SimpleEntry<>(request, clock.instant())) == null; - } - - /** - * Clear the negative cache for given detector. - * @param detectorId AnomalyDetector ID - */ - public void clearFilteredQuery(String detectorId) { - negativeCache.remove(detectorId); - } -} diff --git a/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java new file mode 100644 index 000000000..c80a10f49 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ExecuteForecastResultResponseRecorder.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Optional; + +import org.opensearch.client.Client; +import org.opensearch.commons.authuser.User; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; + +public class ExecuteForecastResultResponseRecorder extends + ExecuteResultResponseRecorder { + + public ExecuteForecastResultResponseRecorder( + ForecastIndexManagement indexManagement, + ResultBulkIndexingHandler resultHandler, + ForecastTaskManager taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager 
taskCacheManager, + int rcfMinSamples + ) { + super( + indexManagement, + resultHandler, + taskManager, + nodeFilter, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + client, + nodeStateManager, + taskCacheManager, + rcfMinSamples, + ForecastIndex.RESULT, + AnalysisType.FORECAST + ); + } + + @Override + protected ForecastResult createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ) { + return new ForecastResult( + configId, + null, // no task id + new ArrayList(), + dataStartTime, + dataEndTime, + executeEndTime, + Instant.now(), + errorMessage, + Optional.empty(), // single-stream forecasters have no entity + user, + indexManagement.getSchemaVersion(resultIndex), + null // no model id + ); + } + + @Override + protected void updateRealtimeTask(ResultResponse response, String configId) { + if (taskManager.skipUpdateRealtimeTask(configId, response.getError())) { + return; + } + + delayedUpdate(response, configId); + } +} diff --git a/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java new file mode 100644 index 000000000..6dc7a5bc6 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ForecastJobProcessor.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast; + +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.JobProcessor; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastJobProcessor extends + JobProcessor { + + private static ForecastJobProcessor INSTANCE; + + public static ForecastJobProcessor getInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (JobProcessor.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new ForecastJobProcessor(); + return INSTANCE; + } + } + + private ForecastJobProcessor() { + // Singleton class, use getJobRunnerInstance method instead of constructor + super(AnalysisType.FORECAST, TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, ForecastResultAction.INSTANCE); + } + + public void registerSettings(Settings settings) { + super.registerSettings(settings, ForecastSettings.FORECAST_MAX_RETRY_FOR_END_RUN_EXCEPTION); + } + + @Override + protected ResultRequest createResultRequest(String configId, long start, long end) { + return new ForecastResultRequest(configId, start, end); + } +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java new file mode 100644 index 000000000..a3fbdab1d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheBuffer.java @@ -0,0 +1,57 @@ 
+/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.caching; + +import java.time.Clock; +import java.time.Duration; + +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.CacheBuffer; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheBuffer extends + CacheBuffer { + + public ForecastCacheBuffer( + int minimumCapacity, + Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, + Duration modelTtl, + long memoryConsumptionPerEntity, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue, + String configId, + long intervalSecs + ) { + super( + 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + intervalSecs, + checkpointWriteQueue, + checkpointMaintainQueue, + configId, + intervalSecs, + Origin.REAL_TIME_FORECASTER + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java new file mode 100644 index 000000000..f93982cc2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastCacheProvider.java @@ -0,0 +1,14 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.caching; + +import org.opensearch.timeseries.caching.CacheProvider; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCacheProvider extends CacheProvider { + +} diff --git a/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java new file mode 100644 index 000000000..5b042b039 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/caching/ForecastPriorityCache.java @@ -0,0 +1,130 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.caching; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Optional; +import java.util.concurrent.Callable; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.caching.PriorityCache; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastPriorityCache extends + PriorityCache { + private ForecastCheckpointWriteWorker checkpointWriteQueue; + private ForecastCheckpointMaintainWorker checkpointMaintainQueue; + + public ForecastPriorityCache( + ForecastCheckpointDao checkpointDao, + int hcDedicatedCacheSize, + Setting checkpointTtl, + int maxInactiveStates, + MemoryTracker memoryTracker, + int numberOfTrees, + Clock clock, + ClusterService clusterService, + Duration modelTtl, + ThreadPool threadPool, + String threadPoolName, + int maintenanceFreqConstant, + Settings settings, + Setting checkpointSavingFreq, + ForecastCheckpointWriteWorker checkpointWriteQueue, + ForecastCheckpointMaintainWorker checkpointMaintainQueue + ) { + super( + checkpointDao, + hcDedicatedCacheSize, + checkpointTtl, + maxInactiveStates, + memoryTracker, + numberOfTrees, + clock, + clusterService, + modelTtl, + threadPool, + threadPoolName, + maintenanceFreqConstant, + settings, + checkpointSavingFreq, + Origin.REAL_TIME_FORECASTER, + FORECAST_DEDICATED_CACHE_SIZE, + FORECAST_MODEL_MAX_SIZE_PERCENTAGE + ); + + this.checkpointWriteQueue = checkpointWriteQueue; + this.checkpointMaintainQueue = checkpointMaintainQueue; + } + + @Override + protected ForecastCacheBuffer createEmptyCacheBuffer(Config config, long requiredMemory) { + return new ForecastCacheBuffer( + config.isHighCardinality() ? 
hcDedicatedCacheSize : 1, + clock, + memoryTracker, + checkpointIntervalHrs, + modelTtl, + requiredMemory, + checkpointWriteQueue, + checkpointMaintainQueue, + config.getId(), + config.getIntervalInSeconds() + ); + } + + @Override + protected Callable> createInactiveEntityCacheLoader(String modelId, String detectorId) { + return new Callable>() { + @Override + public ModelState call() { + return new ModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + Optional.empty(), + new ArrayDeque<>() + ); + } + }; + } + + @Override + protected boolean isDoorKeeperInCacheEnabled() { + return false; + } +} diff --git a/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java index 46de0c762..deb31cad7 100644 --- a/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java +++ b/src/main/java/org/opensearch/forecast/constant/ForecastCommonMessages.java @@ -35,6 +35,7 @@ public class ForecastCommonMessages { public static String FAIL_TO_FIND_FORECASTER_MSG = "Can not find forecaster with id: "; public static final String FORECASTER_ID_MISSING_MSG = "Forecaster ID is missing"; public static final String INVALID_TIMESTAMP_ERR_MSG = "timestamp is invalid"; + public static String FAIL_TO_GET_FORECASTER = "Fail to get forecaster"; // ====================================== // Security @@ -45,10 +46,17 @@ public class ForecastCommonMessages { // ====================================== // Used for custom forecast result index // ====================================== + public static String CAN_NOT_FIND_RESULT_INDEX = "Can't find result index "; public static String INVALID_RESULT_INDEX_PREFIX = "Result index must start with " + CUSTOM_RESULT_INDEX_PREFIX; // ====================================== // Task // ====================================== public static String FORECASTER_IS_RUNNING = "Forecaster is already running"; + + // ====================================== + // Job + // ====================================== + public static String FAIL_TO_START_FORECASTER = "Fail to start forecaster"; + public static String FAIL_TO_STOP_FORECASTER = "Fail to stop forecaster"; } diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java new file mode 100644 index 000000000..664701d7a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastCheckpointDao.java @@ -0,0 +1,463 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.ml; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.ClientUtil; + +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; +import com.google.gson.Gson; + +import io.protostuff.LinkedBuffer; +import io.protostuff.ProtostuffIOUtil; +import io.protostuff.Schema; + +/** + * The ForecastCheckpointDao class implements all the functionality required for fetching, updating and + * removing forecast checkpoints. + * + */ +public class ForecastCheckpointDao extends CheckpointDao { + private static final Logger logger = LogManager.getLogger(ForecastCheckpointDao.class); + static final String LAST_PROCESSED_SAMPLE_FIELD = "last_processed_sample"; + + static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of forecaster"; + + RCFCasterMapper mapper; + private Schema rcfCasterSchema; + + public ForecastCheckpointDao( + Client client, + ClientUtil clientUtil, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + ForecastIndexManagement indexUtil, + RCFCasterMapper mapper, + Schema rcfCasterSchema, + Clock clock + ) { + super( + client, + clientUtil, + ForecastIndex.CHECKPOINT.getIndexName(), + gson, + maxCheckpointBytes, + serializeRCFBufferPool, + serializeRCFBufferSize, + indexUtil, + clock + ); + this.mapper = mapper; + this.rcfCasterSchema = rcfCasterSchema; + } + + /** + * Puts a RCFCaster model checkpoint in the storage. Used in single-stream forecasting. 
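Before the write path below, a self-contained sketch of the protostuff round trip that toCheckpoint and toRCFCaster implement. This is an illustration, not the plugin's code: the class name, model parameters, and buffer size are made up, and RuntimeSchema.getSchema stands in for however rcfCasterSchema is actually provisioned at plugin startup.

import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper;
import com.amazon.randomcutforest.parkservices.state.RCFCasterState;

import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

public final class CasterRoundTripSketch {
    public static void main(String[] args) {
        // a tiny single-feature model: dimensions = number of features x shingle size
        RCFCaster caster = RCFCaster
            .builder()
            .dimensions(8)
            .shingleSize(8)
            .internalShinglingEnabled(true)
            .forecastHorizon(24)
            .build();
        for (int i = 0; i < 200; i++) {
            caster.process(new double[] { Math.sin(i / 10.0) }, 0L);
        }

        RCFCasterMapper mapper = new RCFCasterMapper();
        Schema<RCFCasterState> schema = RuntimeSchema.getSchema(RCFCasterState.class);

        // model -> state -> bytes: what ends up in the checkpoint's model field
        byte[] bytes = ProtostuffIOUtil.toByteArray(mapper.toState(caster), schema, LinkedBuffer.allocate(4096));
        System.out.println("checkpoint payload: " + bytes.length + " bytes");

        // bytes -> state -> model: the inverse path, mirroring toRCFCaster below
        RCFCasterState state = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, state, schema);
        RCFCaster restored = mapper.toModel(state);
        System.out.println("restored model is usable: " + (restored != null));
    }
}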
+     *
+     * @param modelId id of the model
+     * @param caster the RCFCaster model
+     * @param listener onResponse is called with null when the operation is completed
+     */
+    public void putCasterCheckpoint(String modelId, RCFCaster caster, ActionListener<Void> listener) {
+        Map<String, Object> source = new HashMap<>();
+        Optional<byte[]> modelCheckpoint = toCheckpoint(Optional.of(caster));
+        if (modelCheckpoint.isPresent()) {
+            source.put(CommonName.FIELD_MODEL, modelCheckpoint.get());
+            source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC));
+            source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT));
+            putModelCheckpoint(modelId, source, listener);
+        } else {
+            listener.onFailure(new RuntimeException("Fail to create checkpoint to save"));
+        }
+    }
+
+    private Optional<byte[]> toCheckpoint(Optional<RCFCaster> caster) {
+        if (caster.isEmpty()) {
+            return Optional.empty();
+        }
+        // start empty so a failed serialization cannot surface as a null Optional
+        Optional<byte[]> checkpoint = Optional.empty();
+        Map.Entry<LinkedBuffer, Boolean> result = checkoutOrNewBuffer();
+        LinkedBuffer buffer = result.getKey();
+        boolean needCheckin = result.getValue();
+        try {
+            checkpoint = toCheckpoint(caster, buffer);
+        } catch (Exception e) {
+            logger.error("Failed to serialize model", e);
+            if (needCheckin) {
+                try {
+                    serializeRCFBufferPool.invalidateObject(buffer);
+                    needCheckin = false;
+                } catch (Exception x) {
+                    logger.warn("Failed to invalidate buffer", x);
+                }
+                try {
+                    checkpoint = toCheckpoint(caster, LinkedBuffer.allocate(serializeRCFBufferSize));
+                } catch (Exception ex) {
+                    logger.warn("Failed to generate checkpoint", ex);
+                }
+            }
+        } finally {
+            if (needCheckin) {
+                try {
+                    serializeRCFBufferPool.returnObject(buffer);
+                } catch (Exception e) {
+                    logger.warn("Failed to return buffer to pool", e);
+                }
+            }
+        }
+        return checkpoint;
+    }
+
+    private Optional<byte[]> toCheckpoint(Optional<RCFCaster> caster, LinkedBuffer buffer) {
+        if (caster.isEmpty()) {
+            return Optional.empty();
+        }
+        try {
+            byte[] bytes = AccessController.doPrivileged((PrivilegedAction<byte[]>) () -> {
+                RCFCasterState casterState = mapper.toState(caster.get());
+                return ProtostuffIOUtil.toByteArray(casterState, rcfCasterSchema, buffer);
+            });
+            return Optional.of(bytes);
+        } finally {
+            buffer.clear();
+        }
+    }
+
+    /**
+     * Prepare for index request using the contents of the given model state. Used in HC forecasting.
+     * @param modelState an entity model state
+     * @return serialized JSON map or empty map if the state is too bloated
+     * @throws IOException when serialization fails
+     */
+    @Override
+    public Map<String, Object> toIndexSource(ModelState<RCFCaster> modelState) throws IOException {
+        Map<String, Object> source = new HashMap<>();
+        Optional<RCFCaster> model = modelState.getModel();
+
+        Optional<byte[]> serializedModel = toCheckpoint(model);
+        if (serializedModel.isPresent() && serializedModel.get().length <= maxCheckpointBytes) {
+            // we cannot pass Optional as OpenSearch does not know how to serialize an Optional value
+            source.put(CommonName.FIELD_MODEL, serializedModel.get());
+        } else {
+            logger
+                .warn(
+                    new ParameterizedMessage(
+                        "[{}]'s model is empty or too large: [{}] bytes",
+                        modelState.getModelId(),
+                        serializedModel.isPresent() ?
serializedModel.get().length : 0 + ) + ); + } + if (modelState.getSamples() != null && !(modelState.getSamples().isEmpty())) { + source.put(CommonName.ENTITY_SAMPLE_QUEUE, toCheckpoint(modelState.getSamples()).get()); + } + // if there are no samples and no model, no need to index as other information are meta data + if (!source.containsKey(CommonName.ENTITY_SAMPLE_QUEUE) && !source.containsKey(CommonName.FIELD_MODEL)) { + return source; + } + + source.put(ForecastCommonName.FORECASTER_ID_KEY, modelState.getConfigId()); + if (modelState.getLastProcessedSample() != null) { + source.put(LAST_PROCESSED_SAMPLE_FIELD, modelState.getLastProcessedSample()); + } + source.put(CommonName.TIMESTAMP, ZonedDateTime.now(ZoneOffset.UTC)); + source.put(CommonName.SCHEMA_VERSION_FIELD, indexUtil.getSchemaVersion(ForecastIndex.CHECKPOINT)); + + Optional entity = modelState.getEntity(); + if (entity.isPresent()) { + source.put(CommonName.ENTITY_KEY, entity.get()); + } + + return source; + } + + private void deserializeRCFCasterModel(GetResponse response, String rcfModelId, ActionListener> listener) { + Object model = null; + if (response.isExists()) { + try { + model = response.getSource().get(CommonName.FIELD_MODEL); + listener.onResponse(Optional.ofNullable(toRCFCaster((byte[]) model))); + + } catch (Exception e) { + logger.error(new ParameterizedMessage("Unexpected error when deserializing [{}]", rcfModelId), e); + listener.onResponse(Optional.empty()); + } + } else { + listener.onResponse(Optional.empty()); + } + } + + RCFCaster toRCFCaster(byte[] bytes) { + RCFCaster rcfCaster = null; + if (bytes != null && bytes.length > 0) { + try { + RCFCasterState state = rcfCasterSchema.newMessage(); + AccessController.doPrivileged((PrivilegedAction) () -> { + ProtostuffIOUtil.mergeFrom(bytes, state, rcfCasterSchema); + return null; + }); + rcfCaster = mapper.toModel(state); + } catch (RuntimeException e) { + logger.error("Failed to deserialize RCFCaster model", e); + } + } + return rcfCaster; + } + + /** + * Returns to listener the checkpoint for the RCFCaster model. Used in single-stream forecasting. + * + * @param modelId id of the model + * @param listener onResponse is called with the model checkpoint, or empty for no such model + */ + public void getCasterModel(String modelId, ActionListener> listener) { + clientUtil + .asyncRequest( + new GetRequest(indexName, modelId), + client::get, + ActionListener + .wrap( + response -> deserializeRCFCasterModel(response, modelId, listener), + exception -> { + // expected exception, don't print stack trace + if (exception instanceof IndexNotFoundException) { + listener.onResponse(Optional.empty()); + } else { + listener.onFailure(exception); + } + } + ) + ); + } + + /** + * Load json checkpoint into models. Used in HC forecasting. 
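For orientation, a hedged usage sketch of getCasterModel above: restore a single-stream model and score the next data point. The enclosing method, dao, point, and logger are assumed to exist in the caller; ForecastDescriptor is the descriptor type these models emit, and the snippet is illustrative rather than the plugin's actual call site.

void restoreAndForecast(ForecastCheckpointDao dao, String modelId, double[] point) {
    dao.getCasterModel(modelId, ActionListener.wrap(restored -> {
        if (restored.isPresent()) {
            // timestamp is unused here because these models enable internal shingling
            ForecastDescriptor descriptor = restored.get().process(point, 0L);
            logger.info("data confidence: {}", descriptor.getDataConfidence());
        } else {
            logger.info("no checkpoint for [{}]; cold start required", modelId);
        }
    }, e -> logger.error("failed to restore model " + modelId, e)));
}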
+     *
+     * @param checkpoint json checkpoint contents
+     * @param modelId Model Id
+     * @param configId Config Id
+     * @return the restored model state, with its last checkpoint time set; or null if
+     *  the raw checkpoint is corrupted or cannot be deserialized
+     */
+    @Override
+    protected ModelState<RCFCaster> fromEntityModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
+        try {
+            return AccessController.doPrivileged((PrivilegedAction<ModelState<RCFCaster>>) () -> {
+
+                RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId);
+
+                Entity entity = null;
+                Object serializedEntity = checkpoint.get(CommonName.ENTITY_KEY);
+                if (serializedEntity != null) {
+                    try {
+                        entity = Entity.fromJsonArray(serializedEntity);
+                    } catch (Exception e) {
+                        logger.error(new ParameterizedMessage("fail to parse entity [{}]", serializedEntity), e);
+                    }
+                }
+
+                ModelState<RCFCaster> modelState = new ModelState<>(
+                    rcfCaster,
+                    modelId,
+                    configId,
+                    ModelManager.ModelType.RCFCASTER.getName(),
+                    clock,
+                    0,
+                    loadLastProcessedSample(checkpoint, modelId),
+                    Optional.ofNullable(entity),
+                    loadSampleQueue(checkpoint, modelId)
+                );
+
+                modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId));
+
+                return modelState;
+            });
+        } catch (Exception e) {
+            logger.warn("Exception while deserializing checkpoint " + modelId, e);
+            // checkpoint corrupted (e.g., a checkpoint not recognized by current code
+            // due to bugs). Better redo training.
+            return null;
+        }
+    }
+
+    /**
+     * Delete checkpoints associated with a forecaster. Used for HC forecasters.
+     * @param forecasterId Forecaster Id
+     */
+    public void deleteModelCheckpointByForecasterId(String forecasterId) {
+        // A bulk delete request is performed for each batch of matching documents. If a
+        // search or bulk request is rejected, the requests are retried up to 10 times,
+        // with exponential back off. If the maximum retry limit is reached, processing
+        // halts and all failed requests are returned in the response. Any delete
+        // requests that completed successfully still stick, they are not rolled back.
+        DeleteByQueryRequest deleteRequest = new DeleteByQueryRequest(indexName)
+            .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, forecasterId))
+            .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
+            .setAbortOnVersionConflict(false) // when the current delete happens, the previous one might not have finished.
+            // Retry in this case
+            .setRequestsPerSecond(500); // throttle delete requests
+        logger.info("Delete checkpoints of forecaster {}", forecasterId);
+        client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> {
+            if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) {
+                logFailure(response, forecasterId);
+            }
+            // the response can report 0 deleted docs because:
+            // 1) we cannot find matching docs
+            // 2) bad stats from OpenSearch. In this case, docs are deleted, but
+            // OpenSearch says deleted is 0.
+            logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted());
+        }, exception -> {
+            if (exception instanceof IndexNotFoundException) {
+                logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", forecasterId);
+            } else {
+                // the daily cron job will eventually delete any leftover checkpoints
+                logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception);
+            }
+        }));
+    }
+
+    @Override
+    protected DeleteByQueryRequest createDeleteCheckpointRequest(String configId) {
+        return new DeleteByQueryRequest(indexName)
+            .setQuery(new MatchQueryBuilder(ForecastCommonName.FORECASTER_ID_KEY, configId))
+            .setIndicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN)
+            .setAbortOnVersionConflict(false) // when the current delete happens, the previous one might not have finished.
+            // Retry in this case
+            .setRequestsPerSecond(500); // throttle delete requests
+    }
+
+    @Override
+    protected ModelState<RCFCaster> fromSingleStreamModelCheckpoint(Map<String, Object> checkpoint, String modelId, String configId) {
+
+        return AccessController.doPrivileged((PrivilegedAction<ModelState<RCFCaster>>) () -> {
+
+            RCFCaster rcfCaster = loadRCFCaster(checkpoint, modelId);
+
+            ModelState<RCFCaster> modelState = new ModelState<>(
+                rcfCaster,
+                modelId,
+                configId,
+                ModelManager.ModelType.RCFCASTER.getName(),
+                clock,
+                0,
+                loadLastProcessedSample(checkpoint, modelId),
+                Optional.empty(),
+                loadSampleQueue(checkpoint, modelId)
+            );
+
+            modelState.setLastCheckpointTime(loadTimestamp(checkpoint, modelId));
+
+            return modelState;
+        });
+    }
+
+    private RCFCaster loadRCFCaster(Map<String, Object> checkpoint, String modelId) {
+        byte[] model = (byte[]) checkpoint.get(CommonName.FIELD_MODEL);
+        if (model == null || model.length > maxCheckpointBytes) {
+            // avoid a NullPointerException when the model field is absent
+            logger.warn(new ParameterizedMessage("[{}]'s model is missing or too large: [{}] bytes", modelId, model == null ? 0 : model.length));
+            return null;
+        }
+        return toRCFCaster(model);
+    }
+
+    private Sample loadLastProcessedSample(Map<String, Object> checkpoint, String modelId) {
+        String lastProcessedSampleStr = (String) checkpoint.get(LAST_PROCESSED_SAMPLE_FIELD);
+        if (lastProcessedSampleStr == null) {
+            return null;
+        }
+
+        try {
+            return Sample
+                .parse(
+                    JsonXContent.jsonXContent
+                        .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, lastProcessedSampleStr)
+                );
+        } catch (Exception e) {
+            logger.warn("Exception while deserializing last processed sample for " + modelId, e);
+            // checkpoint corrupted (e.g., a checkpoint not recognized by current code
+            // due to bugs). Better redo training.
+            return null;
+        }
+    }
+
+    private Instant loadTimestamp(Map<String, Object> checkpoint, String modelId) {
+        String lastCheckpointTimeString = (String) (checkpoint.get(CommonName.TIMESTAMP));
+        return Instant.parse(lastCheckpointTimeString);
+    }
+
+    private Deque<Sample> loadSampleQueue(Map<String, Object> checkpoint, String modelId) {
+        Deque<Sample> sampleQueue = new ArrayDeque<>();
+        Object samples = checkpoint.get(CommonName.ENTITY_SAMPLE_QUEUE);
+        if (samples != null) {
+            try (
+                XContentParser sampleParser = JsonXContent.jsonXContent
+                    .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.IGNORE_DEPRECATIONS, (String) samples)
+            ) {
+                ensureExpectedToken(XContentParser.Token.START_ARRAY, sampleParser.currentToken(), sampleParser);
+                while (sampleParser.nextToken() != XContentParser.Token.END_ARRAY) {
+                    sampleQueue.add(Sample.parse(sampleParser));
+                }
+            } catch (Exception e) {
+                logger.warn("Exception while deserializing samples for " + modelId, e);
+                return null;
+            }
+        }
+        // sampleQueue stays empty (or we returned null above) when the checkpoint is
+        // corrupted, e.g., not recognized by current code due to bugs. Better redo training.
+ return sampleQueue; + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java new file mode 100644 index 000000000..46cc6d50b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastColdStart.java @@ -0,0 +1,155 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ml; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.config.Precision; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.calibration.Calibration; + +public class ForecastColdStart extends + ModelColdStart { + private double transformDecay; + + public ForecastColdStart( + Clock clock, + ThreadPool threadPool, + NodeStateManager nodeStateManager, + int rcfSampleSize, + int numberOfTrees, + double rcfTimeDecay, + int numMinSamples, + SearchFeatureDao searchFeatureDao, + double thresholdMinPvalue, + FeatureManager featureManager, + Duration modelTtl, + ForecastCheckpointWriteWorker checkpointWriteWorker, + int coolDownMinutes, + long rcfSeed, + double transformDecay, + int defaultTrainSamples, + int maxRoundofColdStart + ) { + // 1 means we sample all real data if possible + super( + modelTtl, + coolDownMinutes, + clock, + threadPool, + numMinSamples, + checkpointWriteWorker, + rcfSeed, + numberOfTrees, + rcfSampleSize, + thresholdMinPvalue, + rcfTimeDecay, + nodeStateManager, + 1, + defaultTrainSamples, + searchFeatureDao, + featureManager, + maxRoundofColdStart, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + AnalysisType.FORECAST + ); + this.transformDecay = transformDecay; + } + + // we deem type conversion safe and thus suppress warnings + @Override + protected void trainModelFromDataSegments( + Pair pointSamplePair, + Optional entity, + ModelState entityState, + Config config + ) { + double[][] dataPoints = pointSamplePair.getKey(); + if (dataPoints == null || dataPoints.length == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + + double[] firstPoint = dataPoints[0]; + if (firstPoint == null || firstPoint.length == 0) { + throw new IllegalArgumentException("Data points must not be empty."); + } + + int 
shingleSize = config.getShingleSize(); + int forecastHorizon = ((Forecaster) config).getHorizon(); + int dimensions = firstPoint.length * shingleSize; + + RCFCaster.Builder casterBuilder = RCFCaster + .builder() + .dimensions(dimensions) + .numberOfTrees(numberOfTrees) + .shingleSize(shingleSize) + .sampleSize(rcfSampleSize) + .internalShinglingEnabled(true) + .precision(Precision.FLOAT_32) + .anomalyRate(1 - this.thresholdMinPvalue) + .outputAfter(numMinSamples) + .calibration(Calibration.MINIMAL) + .timeDecay(rcfTimeDecay) + .parallelExecutionEnabled(false) + .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO) + // the following affects the moving average in many of the transformations + // the 0.02 corresponds to a half life of 1/0.02 = 50 observations + // this is different from the timeDecay() of RCF; however it is a similar + // concept + .transformDecay(transformDecay) + .forecastHorizon(forecastHorizon) + .initialAcceptFraction(initialAcceptFraction); + + if (rcfSeed > 0) { + casterBuilder.randomSeed(rcfSeed); + } + + RCFCaster caster = casterBuilder.build(); + + for (int i = 0; i < dataPoints.length; i++) { + caster.process(dataPoints[i], 0); + } + + entityState.setModel(caster); + entityState.setLastUsedTime(clock.instant()); + entityState.setLastProcessedSample(pointSamplePair.getValue()); + + // save to checkpoint + checkpointWriteWorker.write(entityState, true, RequestPriority.MEDIUM); + } + + @Override + protected boolean isInterpolationInColdStartEnabled() { + return false; + } +} diff --git a/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java new file mode 100644 index 000000000..20e06521d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/ForecastModelManager.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.forecast.ml;
+
+import java.time.Clock;
+import java.util.Locale;
+
+import org.opensearch.forecast.indices.ForecastIndex;
+import org.opensearch.forecast.indices.ForecastIndexManagement;
+import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker;
+import org.opensearch.timeseries.MemoryTracker;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.ml.ModelManager;
+
+import com.amazon.randomcutforest.RandomCutForest;
+import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
+import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
+import com.amazon.randomcutforest.parkservices.RCFCaster;
+
+public class ForecastModelManager extends
+    ModelManager {
+
+    public ForecastModelManager(
+        ForecastCheckpointDao checkpointDao,
+        Clock clock,
+        int rcfNumTrees,
+        int rcfNumSamplesInTree,
+        double rcfTimeDecay,
+        int rcfNumMinSamples,
+        ForecastColdStart entityColdStarter,
+        MemoryTracker memoryTracker,
+        FeatureManager featureManager
+    ) {
+        super(
+            rcfNumTrees,
+            rcfNumSamplesInTree,
+            rcfTimeDecay,
+            rcfNumMinSamples,
+            entityColdStarter,
+            memoryTracker,
+            clock,
+            featureManager,
+            checkpointDao
+        );
+    }
+
+    @Override
+    protected RCFCasterResult createEmptyResult() {
+        return new RCFCasterResult(null, 0, 0, 0);
+    }
+
+    @Override
+    protected RCFCasterResult toResult(RandomCutForest forecast, RCFDescriptor castDescriptor) {
+        if (castDescriptor instanceof ForecastDescriptor) {
+            ForecastDescriptor forecastDescriptor = (ForecastDescriptor) castDescriptor;
+            return new RCFCasterResult(
+                forecastDescriptor.getTimedForecast().rangeVector,
+                forecastDescriptor.getDataConfidence(),
+                forecast.getTotalUpdates(),
+                forecastDescriptor.getRCFScore()
+            );
+        } else {
+            throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported type of descriptor: %s", castDescriptor));
+        }
+    }
+} diff --git a/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java new file mode 100644 index 000000000..3584c7203 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ml/RCFCasterResult.java @@ -0,0 +1,83 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.forecast.ml; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; + +import com.amazon.randomcutforest.returntypes.RangeVector; + +public class RCFCasterResult extends IntermediateResult { + private final RangeVector forecast; + private final double dataQuality; + + public RCFCasterResult(RangeVector forecast, double dataQuality, long totalUpdates, double rcfScore) { + super(totalUpdates, rcfScore); + this.forecast = forecast; + this.dataQuality = dataQuality; + } + + public RangeVector getForecast() { + return forecast; + } + + public double getDataQuality() { + return dataQuality; + } + + @Override + public List toIndexableResults( + Config forecaster, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + List featureData, + Optional entity, + Integer schemaVersion, + String modelId, + String taskId, + String error + ) { + if (forecast.values == null || forecast.values.length == 0) { + return Collections.emptyList(); + } + return ForecastResult + .fromRawRCFCasterResult( + forecaster.getId(), + forecaster.getIntervalInMilliseconds(), + dataQuality, + featureData, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + entity, + forecaster.getUser(), + schemaVersion, + modelId, + forecast.values, + forecast.upper, + forecast.lower, + taskId + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastResult.java b/src/main/java/org/opensearch/forecast/model/ForecastResult.java index 1ce75ff63..fc3855a4e 100644 --- a/src/main/java/org/opensearch/forecast/model/ForecastResult.java +++ b/src/main/java/org/opensearch/forecast/model/ForecastResult.java @@ -175,9 +175,13 @@ public static List fromRawRCFCasterResult( String taskId ) { int inputLength = featureData.size(); - int numberOfForecasts = forecastsValues.length / inputLength; + int numberOfForecasts = 0; + if (forecastsValues != null) { + numberOfForecasts = forecastsValues.length / inputLength; + } - List convertedForecastValues = new ArrayList<>(numberOfForecasts); + // +1 for actual value + List convertedForecastValues = new ArrayList<>(numberOfForecasts + 1); // store feature data and forecast value separately for easy query on feature data // we can join them using forecasterId, entityId, and executionStartTime/executionEndTime diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTask.java b/src/main/java/org/opensearch/forecast/model/ForecastTask.java new file mode 100644 index 000000000..f89e82665 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTask.java @@ -0,0 +1,404 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.model; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +public class ForecastTask extends TimeSeriesTask { + public static final String FORECASTER_ID_FIELD = "forecaster_id"; + public static final String FORECASTER_FIELD = "forecaster"; + public static final String DATE_RANGE_FIELD = "date_range"; + + private Forecaster forecaster = null; + private DateRange dateRange = null; + + private ForecastTask() {} + + public ForecastTask(StreamInput input) throws IOException { + this.taskId = input.readOptionalString(); + this.taskType = input.readOptionalString(); + this.configId = input.readOptionalString(); + if (input.readBoolean()) { + this.forecaster = new Forecaster(input); + } else { + this.forecaster = null; + } + this.state = input.readOptionalString(); + this.taskProgress = input.readOptionalFloat(); + this.initProgress = input.readOptionalFloat(); + this.currentPiece = input.readOptionalInstant(); + this.executionStartTime = input.readOptionalInstant(); + this.executionEndTime = input.readOptionalInstant(); + this.isLatest = input.readOptionalBoolean(); + this.error = input.readOptionalString(); + this.checkpointId = input.readOptionalString(); + this.lastUpdateTime = input.readOptionalInstant(); + this.startedBy = input.readOptionalString(); + this.stoppedBy = input.readOptionalString(); + this.coordinatingNode = input.readOptionalString(); + this.workerNode = input.readOptionalString(); + if (input.readBoolean()) { + this.user = new User(input); + } else { + user = null; + } + // Below are new fields added since AD 1.1 + if (input.available() > 0) { + if (input.readBoolean()) { + this.dateRange = new DateRange(input); + } else { + this.dateRange = null; + } + if (input.readBoolean()) { + this.entity = new Entity(input); + } else { + this.entity = null; + } + this.parentTaskId = input.readOptionalString(); + this.estimatedMinutesLeft = input.readOptionalInt(); + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(taskId); + out.writeOptionalString(taskType); + out.writeOptionalString(configId); + if (forecaster != null) { + out.writeBoolean(true); + forecaster.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(state); + out.writeOptionalFloat(taskProgress); + out.writeOptionalFloat(initProgress); + out.writeOptionalInstant(currentPiece); + out.writeOptionalInstant(executionStartTime); + out.writeOptionalInstant(executionEndTime); + out.writeOptionalBoolean(isLatest); + out.writeOptionalString(error); + out.writeOptionalString(checkpointId); + out.writeOptionalInstant(lastUpdateTime); + out.writeOptionalString(startedBy); + out.writeOptionalString(stoppedBy); + out.writeOptionalString(coordinatingNode); + out.writeOptionalString(workerNode); + if (user != null) { + out.writeBoolean(true); // user 
exists + user.writeTo(out); + } else { + out.writeBoolean(false); // user does not exist + } + // Only forward forecast task to nodes with same version, so it's ok to write these new fields. + if (dateRange != null) { + out.writeBoolean(true); + dateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + if (entity != null) { + out.writeBoolean(true); + entity.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeOptionalString(parentTaskId); + out.writeOptionalInt(estimatedMinutesLeft); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean isEntityTask() { + return ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY.name().equals(taskType); + } + + public static class Builder extends TimeSeriesTask.Builder { + private Forecaster forecaster = null; + private DateRange dateRange = null; + + public Builder() {} + + public Builder forecaster(Forecaster forecaster) { + this.forecaster = forecaster; + return this; + } + + public Builder dateRange(DateRange dateRange) { + this.dateRange = dateRange; + return this; + } + + public ForecastTask build() { + ForecastTask forecastTask = new ForecastTask(); + forecastTask.taskId = this.taskId; + forecastTask.lastUpdateTime = this.lastUpdateTime; + forecastTask.error = this.error; + forecastTask.state = this.state; + forecastTask.configId = this.configId; + forecastTask.taskProgress = this.taskProgress; + forecastTask.initProgress = this.initProgress; + forecastTask.currentPiece = this.currentPiece; + forecastTask.executionStartTime = this.executionStartTime; + forecastTask.executionEndTime = this.executionEndTime; + forecastTask.isLatest = this.isLatest; + forecastTask.taskType = this.taskType; + forecastTask.checkpointId = this.checkpointId; + forecastTask.forecaster = this.forecaster; + forecastTask.startedBy = this.startedBy; + forecastTask.stoppedBy = this.stoppedBy; + forecastTask.coordinatingNode = this.coordinatingNode; + forecastTask.workerNode = this.workerNode; + forecastTask.dateRange = this.dateRange; + forecastTask.entity = this.entity; + forecastTask.parentTaskId = this.parentTaskId; + forecastTask.estimatedMinutesLeft = this.estimatedMinutesLeft; + forecastTask.user = this.user; + + return forecastTask; + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + xContentBuilder = super.toXContent(xContentBuilder, params); + if (configId != null) { + xContentBuilder.field(FORECASTER_ID_FIELD, configId); + } + if (forecaster != null) { + xContentBuilder.field(FORECASTER_FIELD, forecaster); + } + if (dateRange != null) { + xContentBuilder.field(DATE_RANGE_FIELD, dateRange); + } + return xContentBuilder.endObject(); + } + + public static ForecastTask parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ForecastTask parse(XContentParser parser, String taskId) throws IOException { + Instant lastUpdateTime = null; + String startedBy = null; + String stoppedBy = null; + String error = null; + String state = null; + String configId = null; + Float taskProgress = null; + Float initProgress = null; + Instant currentPiece = null; + Instant executionStartTime = null; + Instant executionEndTime = null; + Boolean isLatest = null; + String taskType = null; + String checkpointId = null; + Forecaster forecaster = null; + String parsedTaskId = taskId; + String coordinatingNode = null; + String workerNode = null; + DateRange 
dateRange = null; + Entity entity = null; + String parentTaskId = null; + Integer estimatedMinutesLeft = null; + User user = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case LAST_UPDATE_TIME_FIELD: + lastUpdateTime = ParseUtils.toInstant(parser); + break; + case STARTED_BY_FIELD: + startedBy = parser.text(); + break; + case STOPPED_BY_FIELD: + stoppedBy = parser.text(); + break; + case ERROR_FIELD: + error = parser.text(); + break; + case STATE_FIELD: + state = parser.text(); + break; + case FORECASTER_ID_FIELD: + configId = parser.text(); + break; + case TASK_PROGRESS_FIELD: + taskProgress = parser.floatValue(); + break; + case INIT_PROGRESS_FIELD: + initProgress = parser.floatValue(); + break; + case CURRENT_PIECE_FIELD: + currentPiece = ParseUtils.toInstant(parser); + break; + case EXECUTION_START_TIME_FIELD: + executionStartTime = ParseUtils.toInstant(parser); + break; + case EXECUTION_END_TIME_FIELD: + executionEndTime = ParseUtils.toInstant(parser); + break; + case IS_LATEST_FIELD: + isLatest = parser.booleanValue(); + break; + case TASK_TYPE_FIELD: + taskType = parser.text(); + break; + case CHECKPOINT_ID_FIELD: + checkpointId = parser.text(); + break; + case FORECASTER_FIELD: + forecaster = Forecaster.parse(parser); + break; + case TASK_ID_FIELD: + parsedTaskId = parser.text(); + break; + case COORDINATING_NODE_FIELD: + coordinatingNode = parser.text(); + break; + case WORKER_NODE_FIELD: + workerNode = parser.text(); + break; + case DATE_RANGE_FIELD: + dateRange = DateRange.parse(parser); + break; + case ENTITY_FIELD: + entity = Entity.parse(parser); + break; + case PARENT_TASK_ID_FIELD: + parentTaskId = parser.text(); + break; + case ESTIMATED_MINUTES_LEFT_FIELD: + estimatedMinutesLeft = parser.intValue(); + break; + case USER_FIELD: + user = User.parse(parser); + break; + default: + parser.skipChildren(); + break; + } + } + Forecaster copyForecaster = forecaster == null + ? 
null + : new Forecaster( + configId, + forecaster.getVersion(), + forecaster.getName(), + forecaster.getDescription(), + forecaster.getTimeField(), + forecaster.getIndices(), + forecaster.getFeatureAttributes(), + forecaster.getFilterQuery(), + forecaster.getInterval(), + forecaster.getWindowDelay(), + forecaster.getShingleSize(), + forecaster.getUiMetadata(), + forecaster.getSchemaVersion(), + forecaster.getLastUpdateTime(), + forecaster.getCategoryFields(), + forecaster.getUser(), + forecaster.getCustomResultIndex(), + forecaster.getHorizon(), + forecaster.getImputationOption() + ); + return new Builder() + .taskId(parsedTaskId) + .lastUpdateTime(lastUpdateTime) + .startedBy(startedBy) + .stoppedBy(stoppedBy) + .error(error) + .state(state) + .configId(configId) + .taskProgress(taskProgress) + .initProgress(initProgress) + .currentPiece(currentPiece) + .executionStartTime(executionStartTime) + .executionEndTime(executionEndTime) + .isLatest(isLatest) + .taskType(taskType) + .checkpointId(checkpointId) + .coordinatingNode(coordinatingNode) + .workerNode(workerNode) + .forecaster(copyForecaster) + .dateRange(dateRange) + .entity(entity) + .parentTaskId(parentTaskId) + .estimatedMinutesLeft(estimatedMinutesLeft) + .user(user) + .build(); + } + + @Generated + @Override + public boolean equals(Object other) { + if (this == other) + return true; + if (other == null || getClass() != other.getClass()) + return false; + ForecastTask that = (ForecastTask) other; + return super.equals(that) + && Objects.equal(getConfigId(), that.getConfigId()) + && Objects.equal(getForecaster(), that.getForecaster()) + && Objects.equal(getDateRange(), that.getDateRange()); + } + + @Generated + @Override + public int hashCode() { + int superHashCode = super.hashCode(); + int hash = Objects.hashCode(configId, forecaster, dateRange); + hash += 89 * superHashCode; + return hash; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public DateRange getDateRange() { + return dateRange; + } + + @Override + public String getEntityModelId() { + return entity == null ? null : entity.getModelId(configId).orElse(null); + } + + public void setDateRange(DateRange dateRange) { + this.dateRange = dateRange; + } +} diff --git a/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java new file mode 100644 index 000000000..db84849a0 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/model/ForecastTaskType.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.model; + +import java.util.List; + +import org.opensearch.timeseries.model.TaskType; + +import com.google.common.collect.ImmutableList; + +public enum ForecastTaskType implements TaskType { + FORECAST_REALTIME_SINGLE_STREAM, + FORECAST_REALTIME_HC_FORECASTER, + FORECAST_HISTORICAL_SINGLE_STREAM, + // forecaster level task to track overall state, init progress, error etc. for HC forecaster + FORECAST_HISTORICAL_HC_FORECASTER, + // entity level task to track just one specific entity's state, init progress, error etc. 
+ FORECAST_HISTORICAL_HC_ENTITY; + + public static List HISTORICAL_FORECASTER_TASK_TYPES = ImmutableList + .of(ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM); + public static List ALL_HISTORICAL_TASK_TYPES = ImmutableList + .of( + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ); + public static List REALTIME_TASK_TYPES = ImmutableList + .of(ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER); + public static List ALL_FORECAST_TASK_TYPES = ImmutableList + .of( + ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM, + ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM, + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER, + ForecastTaskType.FORECAST_HISTORICAL_HC_ENTITY + ); +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java new file mode 100644 index 000000000..a7b701cc2 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointMaintainWorker.java @@ -0,0 +1,90 @@ +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointMaintainWorker extends + CheckpointMaintainWorker { + public static final String WORKER_NAME = "forecast-checkpoint-maintain"; + + public ForecastCheckpointMaintainWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + 
threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + converter, + AnalysisType.FORECAST + ); + + this.batchSize = FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS + .get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer( + FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + it -> this.expectedExecutionTimeInMilliSecsPerRequest = it + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java new file mode 100644 index 000000000..f6f7f47d8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointReadWorker.java @@ -0,0 +1,144 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointReadWorker extends + CheckpointReadWorker { + public static final String WORKER_NAME = "forecast-checkpoint-read"; + + public ForecastCheckpointReadWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService 
clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastModelManager modelManager, + ForecastCheckpointDao checkpointDao, + ForecastColdStartWorker entityColdStartQueue, + ForecastResultWriteWorker resultWriteQueue, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + Provider cacheProvider, + Duration stateTtl, + ForecastCheckpointWriteWorker checkpointWriteQueue, + Stats adStats + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + executionTtl, + modelManager, + checkpointDao, + entityColdStartQueue, + resultWriteQueue, + stateManager, + indexUtil, + cacheProvider, + stateTtl, + checkpointWriteQueue, + adStats, + FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastCommonName.FORECAST_CHECKPOINT_INDEX_NAME, + StatNames.FORECAST_MODEL_CORRUTPION_COUNT, + AnalysisType.FORECAST + ); + } + + @Override + protected void saveResult(RCFCasterResult result, Config config, FeatureRequest origRequest, Optional entity, String modelId) { + if (result != null && result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()), + Instant.now(), + Instant.now(), + ParseUtils.getFeatureData(origRequest.getCurrentFeature(), config), + entity, + indexUtil.getSchemaVersion(ForecastIndex.RESULT), + modelId, + null, + null + ); + + for (ForecastResult r : indexableResults) { + resultWriteWorker + .put( + new ForecastResultWriteRequest( + origRequest.getExpirationEpochMs(), + config.getId(), + RequestPriority.MEDIUM, + r, + config.getCustomResultIndex() + ) + ); + } + } + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java new file mode 100644 index 000000000..42a72e844 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastCheckpointWriteWorker.java @@ -0,0 +1,76 @@ +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import 
com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastCheckpointWriteWorker extends + CheckpointWriteWorker { + public static final String WORKER_NAME = "forecast-checkpoint-write"; + + public ForecastCheckpointWriteWorker( + long heapSize, + int singleRequestSize, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastCheckpointDao checkpoint, + String indexName, + Duration checkpointInterval, + NodeStateManager timeSeriesNodeStateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSize, + singleRequestSize, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + timeSeriesNodeStateManager, + checkpoint, + indexName, + checkpointInterval, + AnalysisType.FORECAST + ); + } + +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java new file mode 100644 index 000000000..0a2711a6e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdEntityWorker.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS; + +import java.time.Clock; +import java.time.Duration; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.ColdEntityWorker; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * A queue slowly releasing low-priority requests to CheckpointReadQueue + * + * ColdEntityQueue is a queue to absorb cold entities. 
Like hot entities, we load a cold + * entity's model checkpoint from disk, train models if the checkpoint is not found, + * query for missed features to complete a shingle, use the models to produce forecasts, + * update the models, and save the forecast results to disk. + * Implementation-wise, we reuse the queues developed for hot entities. + * The differences are: we process hot entities as long as resources (e.g., the + * forecast thread pool has availability) are available, while we release cold entity requests + * to other queues at a slow, controlled pace. Also, cold entity requests' priority is low, + * so we process cold entity requests only when there are no hot entity requests waiting. + * + */ +public class ForecastColdEntityWorker extends + ColdEntityWorker { + public static final String WORKER_NAME = "forecast-cold-entity"; + + public ForecastColdEntityWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService forecastCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + ForecastCheckpointReadWorker checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + forecastCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnalysisType.FORECAST + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java new file mode 100644 index 000000000..6cceb79cd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastColdStartWorker.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; + +import 
com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastColdStartWorker extends + ColdStartWorker { + public static final String WORKER_NAME = "forecast-hc-cold-start"; + + public ForecastColdStartWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService circuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastColdStart coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + ForecastPriorityCache cacheProvider + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_COLD_START_QUEUE_CONCURRENCY, + executionTtl, + coldStarter, + stateTtl, + nodeStateManager, + cacheProvider, + AnalysisType.FORECAST + ); + } + + @Override + protected ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId) { + return new ModelState( + null, + modelId, + configId, + ModelManager.ModelType.RCFCASTER.getName(), + clock, + 0, + new Sample(), + coldStartRequest.getEntity(), + new ArrayDeque<>() + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java new file mode 100644 index 000000000..f9fd07c25 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteRequest.java @@ -0,0 +1,36 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.ratelimit; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ForecastResultWriteRequest extends ResultWriteRequest { + + public ForecastResultWriteRequest( + long expirationEpochMs, + String detectorId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + super(expirationEpochMs, detectorId, priority, result, resultIndex); + } + + public ForecastResultWriteRequest(StreamInput in) throws IOException { + super(in, ForecastResult::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java new file mode 100644 index 000000000..ed7684c62 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/ratelimit/ForecastResultWriteWorker.java @@ -0,0 +1,108 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
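The slow-release behavior described in the ForecastColdEntityWorker javadoc above boils down to draining a low-priority queue at a fixed cadence instead of on demand. A minimal, framework-free sketch of that pacing idea follows; the class and parameter names are illustrative, not the plugin's own.

import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

public final class PacedColdQueue<T> {
    private final LinkedBlockingQueue<T> lowPriority = new LinkedBlockingQueue<>();
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    // Release at most batchSize requests per tick so cold entities trickle into the
    // downstream (checkpoint read) queue instead of competing with hot-entity work.
    public PacedColdQueue(int batchSize, long tickMillis, Consumer<T> downstream) {
        scheduler.scheduleAtFixedRate(() -> {
            for (int i = 0; i < batchSize; i++) {
                T request = lowPriority.poll();
                if (request == null) {
                    return; // nothing queued; wait for the next tick
                }
                downstream.accept(request);
            }
        }, tickMillis, tickMillis, TimeUnit.MILLISECONDS);
    }

    public void put(T request) {
        lowPriority.add(request);
    }
}

The real worker derives its cadence from FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS and the checkpoint-read batch size rather than a fixed tick.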
+ */ + +package org.opensearch.forecast.ratelimit; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; + +public class ForecastResultWriteWorker extends + ResultWriteWorker { + public static final String WORKER_NAME = "forecast-result-write"; + + public ForecastResultWriteWorker( + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Duration executionTtl, + ForecastIndexMemoryPressureAwareResultHandler resultHandler, + NamedXContentRegistry xContentRegistry, + NodeStateManager stateManager, + Duration stateTtl + ) { + super( + WORKER_NAME, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + executionTtl, + FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + stateTtl, + stateManager, + resultHandler, + xContentRegistry, + ForecastResult::parse, + AnalysisType.FORECAST + ); + } + + @Override + protected ForecastResultBulkRequest toBatchRequest(List toProcess) { + final ForecastResultBulkRequest bulkRequest = new ForecastResultBulkRequest(); + for (ForecastResultWriteRequest request : toProcess) { + bulkRequest.add(request); + } + return bulkRequest; + } + + @Override + protected ForecastResultWriteRequest createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ForecastResult result, + String resultIndex + ) { + return new ForecastResultWriteRequest(expirationEpochMs, configId, priority, result, resultIndex); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java new file mode 100644 index 000000000..ccefb7d04 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/AbstractForecasterAction.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this 
file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_WINDOW_DELAY; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_FORECAST_FEATURES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_HC_FORECASTERS; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.forecast.settings.ForecastNumericSetting; +import org.opensearch.rest.BaseRestHandler; + +public abstract class AbstractForecasterAction extends BaseRestHandler { + protected volatile TimeValue requestTimeout; + protected volatile TimeValue forecastInterval; + protected volatile TimeValue forecastWindowDelay; + protected volatile Integer maxSingleStreamForecasters; + protected volatile Integer maxHCForecasters; + protected volatile Integer maxForecastFeatures; + protected volatile Integer maxCategoricalFields; + + public AbstractForecasterAction(Settings settings, ClusterService clusterService) { + this.requestTimeout = FORECAST_REQUEST_TIMEOUT.get(settings); + this.forecastInterval = FORECAST_INTERVAL.get(settings); + this.forecastWindowDelay = FORECAST_WINDOW_DELAY.get(settings); + this.maxSingleStreamForecasters = MAX_SINGLE_STREAM_FORECASTERS.get(settings); + this.maxHCForecasters = MAX_HC_FORECASTERS.get(settings); + this.maxForecastFeatures = MAX_FORECAST_FEATURES; + this.maxCategoricalFields = ForecastNumericSetting.maxCategoricalFields(); + // TODO: will add more cluster setting consumer later + // TODO: inject ClusterSettings only if clusterService is only used to get ClusterSettings + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_REQUEST_TIMEOUT, it -> requestTimeout = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INTERVAL, it -> forecastInterval = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_WINDOW_DELAY, it -> forecastWindowDelay = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_SINGLE_STREAM_FORECASTERS, it -> maxSingleStreamForecasters = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_HC_FORECASTERS, it -> maxHCForecasters = it); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java new file mode 100644 index 000000000..9ba626fcd --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/ForecasterExecutionInput.java @@ -0,0 +1,141 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
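AbstractForecasterAction above follows the standard OpenSearch pattern for dynamic settings: seed a volatile field from the injected Settings, then register a consumer so later cluster-setting updates become visible to request threads without a restart. Stripped of the OpenSearch types, the pattern looks like this; DynamicSettings and the setting key are stand-ins for ClusterSettings and the real plugins.forecast.* keys.

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

// Simplified stand-in for ClusterSettings: fires a registered consumer when a key changes.
final class DynamicSettings {
    private final Map<String, Consumer<String>> consumers = new ConcurrentHashMap<>();

    void addSettingsUpdateConsumer(String key, Consumer<String> consumer) {
        consumers.put(key, consumer);
    }

    void apply(String key, String newValue) {
        Consumer<String> c = consumers.get(key);
        if (c != null) {
            c.accept(newValue);
        }
    }
}

public class TimeoutHolder {
    // volatile: read on request threads, written by the settings-update thread.
    private volatile long requestTimeoutMillis;

    public TimeoutHolder(DynamicSettings settings, long initialMillis) {
        this.requestTimeoutMillis = initialMillis;
        settings.addSettingsUpdateConsumer("plugins.forecast.request_timeout", v -> requestTimeoutMillis = Long.parseLong(v));
    }

    public long requestTimeoutMillis() {
        return requestTimeoutMillis;
    }
}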
+ */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.annotation.Generated; +import org.opensearch.timeseries.util.ParseUtils; + +import com.google.common.base.Objects; + +/** + * Input data needed to trigger forecaster. + */ +public class ForecasterExecutionInput implements ToXContentObject { + + private static final String FORECASTER_ID_FIELD = "forecaster_id"; + private static final String PERIOD_START_FIELD = "period_start"; + private static final String PERIOD_END_FIELD = "period_end"; + private static final String FORECASTER_FIELD = "forecaster"; + private Instant periodStart; + private Instant periodEnd; + private String forecasterId; + private Forecaster forecaster; + + public ForecasterExecutionInput(String forecasterId, Instant periodStart, Instant periodEnd, Forecaster forecaster) { + this.periodStart = periodStart; + this.periodEnd = periodEnd; + this.forecasterId = forecasterId; + this.forecaster = forecaster; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder + .startObject() + .field(FORECASTER_ID_FIELD, forecasterId) + .field(PERIOD_START_FIELD, periodStart.toEpochMilli()) + .field(PERIOD_END_FIELD, periodEnd.toEpochMilli()) + .field(FORECASTER_FIELD, forecaster); + return xContentBuilder.endObject(); + } + + public static ForecasterExecutionInput parse(XContentParser parser) throws IOException { + return parse(parser, null); + } + + public static ForecasterExecutionInput parse(XContentParser parser, String inputConfigId) throws IOException { + Instant periodStart = null; + Instant periodEnd = null; + Forecaster forecaster = null; + String forecasterId = null; + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case FORECASTER_ID_FIELD: + forecasterId = parser.text(); + break; + case PERIOD_START_FIELD: + periodStart = ParseUtils.toInstant(parser); + break; + case PERIOD_END_FIELD: + periodEnd = ParseUtils.toInstant(parser); + break; + case FORECASTER_FIELD: + XContentParser.Token token = parser.currentToken(); + if (parser.currentToken().equals(XContentParser.Token.START_OBJECT)) { + forecaster = Forecaster.parse(parser, forecasterId); + } + break; + default: + break; + } + } + if (!Strings.isNullOrEmpty(inputConfigId)) { + forecasterId = inputConfigId; + } + return new ForecasterExecutionInput(forecasterId, periodStart, periodEnd, forecaster); + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ForecasterExecutionInput that = (ForecasterExecutionInput) o; + return Objects.equal(periodStart, that.periodStart) + && Objects.equal(periodEnd, that.periodEnd) + && Objects.equal(forecasterId, that.forecasterId) + && Objects.equal(forecaster, that.forecaster); + } + + @Generated + @Override + public int hashCode() { + return 
Objects.hashCode(periodStart, periodEnd, forecasterId); + } + + public Instant getPeriodStart() { + return periodStart; + } + + public Instant getPeriodEnd() { + return periodEnd; + } + + public String getForecasterId() { + return forecasterId; + } + + public void setForecasterId(String forecasterId) { + this.forecasterId = forecasterId; + } + + public Forecaster getForecaster() { + return forecaster; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestExecuteForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestExecuteForecasterAction.java new file mode 100644 index 000000000..5cc9b24cc --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestExecuteForecasterAction.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.RUN; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.commons.lang.StringUtils; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to handle request to forecast. 
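ForecasterExecutionInput.parse above is the standard XContentParser token loop: assert START_OBJECT, advance, then switch on field names until END_OBJECT. For a compact picture of what the loop extracts, the same request body can be decoded with Jackson (assuming jackson-databind on the classpath); the demo class itself is illustrative.

import java.time.Instant;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public final class ExecutionInputParseDemo {
    public static void main(String[] args) throws Exception {
        String body = "{\"forecaster_id\":\"f1\",\"period_start\":1700000000000,\"period_end\":1700003600000}";
        JsonNode node = new ObjectMapper().readTree(body);
        // Same fields the REST layer reads before building a ForecastResultRequest.
        String forecasterId = node.path("forecaster_id").asText(null);
        Instant periodStart = Instant.ofEpochMilli(node.path("period_start").asLong());
        Instant periodEnd = Instant.ofEpochMilli(node.path("period_end").asLong());
        if (!periodStart.isBefore(periodEnd)) {
            throw new IllegalArgumentException("Period start date should be before end date");
        }
        System.out.println(forecasterId + ": " + periodStart + " -> " + periodEnd);
    }
}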
+ */ +public class RestExecuteForecasterAction extends BaseRestHandler { + + public static final String FORECASTER_ACTION = "execute_forecaster"; + + public RestExecuteForecasterAction() {} + + @Override + public String getName() { + return FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + ForecasterExecutionInput input = getForecasterExecutionInput(request); + return channel -> { + String error = validateForecasterExecutionInput(input); + if (StringUtils.isNotBlank(error)) { + channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, error)); + return; + } + + ForecastResultRequest getRequest = new ForecastResultRequest( + input.getForecasterId(), + input.getPeriodStart().toEpochMilli(), + input.getPeriodEnd().toEpochMilli() + ); + client.execute(ForecastResultAction.INSTANCE, getRequest, new RestToXContentListener<>(channel)); + }; + } + + private ForecasterExecutionInput getForecasterExecutionInput(RestRequest request) throws IOException { + String forecasterId = null; + if (request.hasParam(FORECASTER_ID)) { + forecasterId = request.param(FORECASTER_ID); + } + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ForecasterExecutionInput input = ForecasterExecutionInput.parse(parser, forecasterId); + if (forecasterId != null) { + input.setForecasterId(forecasterId); + } + return input; + } + + private String validateForecasterExecutionInput(ForecasterExecutionInput input) { + if (StringUtils.isBlank(input.getForecasterId())) { + return "Must set forecaster id"; + } + if (input.getPeriodStart() == null || input.getPeriodEnd() == null) { + return "Must set both period start and end date with epoch of milliseconds"; + } + if (!input.getPeriodStart().isBefore(input.getPeriodEnd())) { + return "Period start date should be before end date"; + } + return null; + } + + @Override + public List<Route> routes() { + return ImmutableList + .of( + // execute forecaster once + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, RUN) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java new file mode 100644 index 000000000..e13ce608f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestForecasterJobAction.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; +import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import 
org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.rest.RestJobAction; +import org.opensearch.timeseries.transport.JobRequest; + +import com.google.common.collect.ImmutableList; + +public class RestForecasterJobAction extends RestJobAction { + public static final String FORECAST_JOB_ACTION = "forecaster_job_action"; + + @Override + public String getName() { + return FORECAST_JOB_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + String forecasterId = request.param(FORECASTER_ID); + boolean historical = request.paramAsBoolean("historical", false); + String rawPath = request.rawPath(); + DateRange dateRange = parseInputDateRange(request); + + JobRequest forecasterJobRequest = new JobRequest(forecasterId, dateRange, historical, rawPath); + + return channel -> client.execute(ForecasterJobAction.INSTANCE, forecasterJobRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List<Route> routes() { + return ImmutableList + .of( + // start forecaster job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, START_JOB) + ), + // stop forecaster job + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/{%s}/%s", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID, STOP_JOB) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java new file mode 100644 index 000000000..76ee470f8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestGetForecasterAction.java @@ -0,0 +1,141 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.TYPE; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestActions; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.RestHandlerUtils; + +import com.google.common.collect.ImmutableList; + +/** + * This class consists of the REST handler to retrieve a forecaster.
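For orientation, the two job routes registered above expand to concrete request paths as shown below. The URI constants are not visible in this diff, so the values used here are assumptions: FORECAST_FORECASTERS_URI as _plugins/_forecast/forecasters, START_JOB as _start, STOP_JOB as _stop, and FORECASTER_ID as forecasterID.

import java.util.Locale;

public class JobRoutesDemo {
    public static void main(String[] args) {
        // Assumed values; the real constants live in TimeSeriesAnalyticsPlugin and RestHandlerUtils.
        String base = "_plugins/_forecast/forecasters";
        System.out.println("POST " + String.format(Locale.ROOT, "%s/{%s}/%s", base, "forecasterID", "_start"));
        System.out.println("POST " + String.format(Locale.ROOT, "%s/{%s}/%s", base, "forecasterID", "_stop"));
        // -> POST _plugins/_forecast/forecasters/{forecasterID}/_start
        // -> POST _plugins/_forecast/forecasters/{forecasterID}/_stop
    }
}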
+ */ +public class RestGetForecasterAction extends BaseRestHandler { + + private static final String GET_FORECASTER_ACTION = "get_forecaster"; + + public RestGetForecasterAction() {} + + @Override + public String getName() { + return GET_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + String forecasterId = request.param(FORECASTER_ID); + String typesStr = request.param(TYPE); + + String rawPath = request.rawPath(); + boolean returnJob = request.paramAsBoolean("job", false); + boolean returnTask = request.paramAsBoolean("task", false); + boolean all = request.paramAsBoolean("_all", false); + GetConfigRequest getForecasterRequest = new GetConfigRequest( + forecasterId, + RestActions.parseVersion(request), + returnJob, + returnTask, + typesStr, + rawPath, + all, + RestHandlerUtils.buildEntity(request, forecasterId) + ); + + return channel -> client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, new RestToXContentListener<>(channel)); + } + + @Override + public List<Route> routes() { + return ImmutableList + .of( + // OpenSearch-only API. Considering users may provide entity in the search body, + // support POST as well. + + // profile API + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // type is a profile name. See the complete list of supported profile names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.POST, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE + ) + ), + // type is a profile name. See the complete list of supported profile names in + // org.opensearch.ad.model.ProfileName. + new Route( + RestRequest.Method.GET, + String + .format( + Locale.ROOT, + "%s/{%s}/%s/{%s}", + TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, + FORECASTER_ID, + RestHandlerUtils.PROFILE, + TYPE + ) + ), + + // get forecaster API + new Route( + RestRequest.Method.GET, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java new file mode 100644 index 000000000..8605621f9 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/RestIndexForecasterAction.java @@ -0,0 +1,138 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
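Because GetConfigRequest takes several positional flags, the same construction as in RestGetForecasterAction.prepareRequest is repeated here with per-argument comments. The notes are the editor's reading of the names in this diff, not upstream documentation, and the fragment compiles only inside the handler.

GetConfigRequest getForecasterRequest = new GetConfigRequest(
    forecasterId,                          // forecaster id from the REST path
    RestActions.parseVersion(request),     // optional ?version= parameter
    returnJob,                             // ?job=true: also return the scheduled job document
    returnTask,                            // ?task=true: also return the realtime/historical task
    typesStr,                              // profile type(s) when one of the _profile routes matched
    rawPath,                               // raw request path, used to distinguish profile calls
    all,                                   // ?_all=true: return every available profile section
    RestHandlerUtils.buildEntity(request, forecasterId) // entity, for per-entity profiles of HC forecasters
);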
+ */ + +package org.opensearch.forecast.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.FORECASTER_ID; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_PRIMARY_TERM; +import static org.opensearch.timeseries.util.RestHandlerUtils.IF_SEQ_NO; +import static org.opensearch.timeseries.util.RestHandlerUtils.REFRESH; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterRequest; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.Config; + +import com.google.common.collect.ImmutableList; + +/** + * Rest handlers to create and update forecaster. + */ +public class RestIndexForecasterAction extends AbstractForecasterAction { + private static final String INDEX_FORECASTER_ACTION = "index_forecaster_action"; + private final Logger logger = LogManager.getLogger(RestIndexForecasterAction.class); + + public RestIndexForecasterAction(Settings settings, ClusterService clusterService) { + super(settings, clusterService); + } + + @Override + public String getName() { + return INDEX_FORECASTER_ACTION; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new IllegalStateException(ForecastCommonMessages.DISABLED_ERR_MSG); + } + + String forecasterId = request.param(FORECASTER_ID, Config.NO_ID); + logger.info("Forecaster {} action for forecasterId {}", request.method(), forecasterId); + + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Forecaster forecaster = Forecaster.parse(parser, forecasterId, null, forecastInterval, forecastWindowDelay); + + long seqNo = request.paramAsLong(IF_SEQ_NO, SequenceNumbers.UNASSIGNED_SEQ_NO); + long primaryTerm = request.paramAsLong(IF_PRIMARY_TERM, SequenceNumbers.UNASSIGNED_PRIMARY_TERM); + WriteRequest.RefreshPolicy refreshPolicy = request.hasParam(REFRESH) + ? 
WriteRequest.RefreshPolicy.parse(request.param(REFRESH)) + : WriteRequest.RefreshPolicy.IMMEDIATE; + RestRequest.Method method = request.getHttpRequest().method(); + + IndexForecasterRequest indexForecasterRequest = new IndexForecasterRequest( + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + method, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields + ); + + return channel -> client + .execute(IndexForecasterAction.INSTANCE, indexForecasterRequest, indexForecasterResponse(channel, method)); + } + + @Override + public List<Route> routes() { + return ImmutableList + .of( + // Create + new Route(RestRequest.Method.POST, TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI), + // Update + new Route( + RestRequest.Method.PUT, + String.format(Locale.ROOT, "%s/{%s}", TimeSeriesAnalyticsPlugin.FORECAST_FORECASTERS_URI, FORECASTER_ID) + ) + ); + } + + private RestResponseListener<IndexForecasterResponse> indexForecasterResponse(RestChannel channel, RestRequest.Method method) { + return new RestResponseListener<IndexForecasterResponse>(channel) { + @Override + public RestResponse buildResponse(IndexForecasterResponse response) throws Exception { + RestStatus restStatus = RestStatus.CREATED; + if (method == RestRequest.Method.PUT) { + restStatus = RestStatus.OK; + } + BytesRestResponse bytesRestResponse = new BytesRestResponse( + restStatus, + response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS) + ); + if (restStatus == RestStatus.CREATED) { + String location = String.format(Locale.ROOT, "%s/%s", TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI, response.getId()); + bytesRestResponse.addHeader("Location", location); + } + return bytesRestResponse; + } + }; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java new file mode 100644 index 000000000..60af3bb0c --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/AbstractForecasterActionHandler.java @@ -0,0 +1,236 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
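indexForecasterResponse above encodes the usual REST create/update convention: a POST-create returns 201 Created plus a Location header pointing at the new forecaster, a PUT-update returns 200 OK. A framework-free sketch of that decision; the base URI value is assumed, standing in for TimeSeriesAnalyticsPlugin.FORECAST_BASE_URI.

import java.util.Locale;

public final class IndexResponseStatusDemo {
    // method is "POST" for create or "PUT" for update; id is the indexed forecaster id.
    static String describeReply(String method, String id) {
        if ("PUT".equals(method)) {
            return "200 OK"; // update: no Location header
        }
        String location = String.format(Locale.ROOT, "%s/%s", "_plugins/_forecast", id);
        return "201 Created, Location: " + location; // create: point the client at the new resource
    }

    public static void main(String[] args) {
        System.out.println(describeReply("POST", "abc123"));
        System.out.println(describeReply("PUT", "abc123"));
    }
}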
+ */ + +package org.opensearch.forecast.rest.handler; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.rest.handler.AbstractTimeSeriesActionHandler; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class AbstractForecasterActionHandler extends + AbstractTimeSeriesActionHandler { + protected final Logger logger = LogManager.getLogger(AbstractForecasterActionHandler.class); + + public static final String EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG = "Can't create more than %d HC forecasters."; + public static final String EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG = "Can't create more than %d single-stream forecasters."; + public static final String NO_DOCS_IN_USER_INDEX_MSG = "Can't create forecasters as no document is found in the indices: "; + public static final String DUPLICATE_FORECASTER_MSG = "Cannot create forecasters with name [%s] as it's already used by forecaster %s"; + public static final String VALIDATION_FEATURE_FAILURE = "Validation failed for feature(s) of forecaster %s"; + + protected final Integer maxSingleStreamForecasters; + protected final Integer maxHCForecasters; + + /** + * Constructor function. 
+ * + * @param clusterService ClusterService + * @param client ES node client that executes actions on the local node + * @param clientUtil Forecast security client + * @param transportService ES transport service + * @param forecastIndices forecast index manager + * @param forecasterId forecaster identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param forecaster forecaster instance + * @param requestTimeout request time out configuration + * @param maxSingleStreamForecasters max single-stream forecasters allowed + * @param maxHCForecasters max HC forecasters allowed + * @param maxFeatures max features allowed per forecaster + * @param maxCategoricalFields max categorical fields allowed + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + * @param clock clock object to know when to timeout + * @param isDryRun Whether handler is dryrun or not + */ + public AbstractForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ForecastIndexManagement forecastIndices, + String forecasterId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxFeatures, + Integer maxCategoricalFields, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + SearchFeatureDao searchFeatureDao, + String validationType, + boolean isDryRun + ) { + super( + forecaster, + forecastIndices, + isDryRun, + client, + forecasterId, + clientUtil, + user, + method, + clusterService, + xContentRegistry, + transportService, + requestTimeout, + refreshPolicy, + seqNo, + primaryTerm, + validationType, + searchFeatureDao, + maxFeatures, + maxCategoricalFields, + AnalysisType.FORECAST + ); + this.maxSingleStreamForecasters = maxSingleStreamForecasters; + this.maxHCForecasters = maxHCForecasters; + } + + @Override + protected TimeSeriesException createValidationException(String msg, ValidationIssueType type) { + return new ValidationException(msg, type, ValidationAspect.FORECASTER); + } + + @Override + protected Forecaster parse(XContentParser parser, GetResponse response) throws IOException { + return Forecaster.parse(parser, response.getId(), response.getVersion()); + } + + // TODO: add method body once backtesting implementation is ready + @Override + protected void confirmHistoricalRunning(String id, ActionListener listener) { + + } + + @Override + protected String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_SINGLE_STREAM_FORECASTERS_PREFIX_MSG, getMaxSingleStreamConfigs()); + } + + @Override + protected String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs) { + return String.format(Locale.ROOT, EXCEEDED_MAX_HC_FORECASTERS_PREFIX_MSG, getMaxHCConfigs()); + } + + @Override + protected String getNoDocsInUserIndexErrorMsg(String suppliedIndices) { + return String.format(Locale.ROOT, NO_DOCS_IN_USER_INDEX_MSG, suppliedIndices); + } + + @Override + protected String getDuplicateConfigErrorMsg(String name, List otherConfigIds) { + return String.format(Locale.ROOT, DUPLICATE_FORECASTER_MSG, name, otherConfigIds); + } + + @Override + protected Config copyConfig(User user, Config config) { + return new 
Forecaster( + config.getId(), + config.getVersion(), + config.getName(), + config.getDescription(), + config.getTimeField(), + config.getIndices(), + config.getFeatureAttributes(), + config.getFilterQuery(), + config.getInterval(), + config.getWindowDelay(), + config.getShingleSize(), + config.getUiMetadata(), + config.getSchemaVersion(), + Instant.now(), + config.getCategoryFields(), + user, + config.getCustomResultIndex(), + ((Forecaster) config).getHorizon(), + config.getImputationOption() + ); + } + + @SuppressWarnings("unchecked") + @Override + protected T createIndexConfigResponse(IndexResponse indexResponse, Config config) { + return (T) new IndexForecasterResponse( + indexResponse.getId(), + indexResponse.getVersion(), + indexResponse.getSeqNo(), + indexResponse.getPrimaryTerm(), + (Forecaster) config, + RestStatus.CREATED + ); + } + + @Override + protected Set getDefaultValidationType() { + return Sets.newHashSet(ValidationAspect.FORECASTER); + } + + @Override + protected String getFeatureErrorMsg(String name) { + return String.format(Locale.ROOT, VALIDATION_FEATURE_FAILURE, name); + } + + @Override + protected Integer getMaxSingleStreamConfigs() { + return maxSingleStreamForecasters; + } + + @Override + protected Integer getMaxHCConfigs() { + return maxHCForecasters; + } + + @Override + protected void validateModel(ActionListener listener) { + // TODO: add model validation and return with listener + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java new file mode 100644 index 000000000..b0fc6cdc8 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/ForecastIndexJobActionHandler.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.rest.handler; + +import static org.opensearch.forecast.model.ForecastTaskType.HISTORICAL_FORECASTER_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import java.util.List; + +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultRequest; +import org.opensearch.forecast.transport.StopForecasterAction; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastIndexJobActionHandler extends + IndexJobActionHandler { + + public ForecastIndexJobActionHandler( + Client client, + ForecastIndexManagement indexManagement, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager adTaskManager, + ExecuteForecastResultResponseRecorder recorder, + NodeStateManager nodeStateManager, + Settings settings 
+ ) { + super( + client, + indexManagement, + xContentRegistry, + adTaskManager, + recorder, + ForecastResultAction.INSTANCE, + AnalysisType.FORECAST, + ForecastIndex.STATE.getIndexName(), + StopForecasterAction.INSTANCE, + nodeStateManager, + settings, + FORECAST_REQUEST_TIMEOUT + ); + } + + @Override + protected ResultRequest createResultRequest(String configID, long start, long end) { + return new ForecastResultRequest(configID, start, end); + } + + @Override + protected List<ForecastTaskType> getHistorialConfigTaskTypes() { + return HISTORICAL_FORECASTER_TASK_TYPES; + } +} diff --git a/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java new file mode 100644 index 000000000..7d389dd59 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/rest/handler/IndexForecasterActionHandler.java @@ -0,0 +1,99 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.rest.handler; + +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.transport.IndexForecasterResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +/** + * Process create/update forecaster requests. + * + */ +public class IndexForecasterActionHandler extends AbstractForecasterActionHandler { + /** + * Constructor function.
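ForecastIndexJobActionHandler above only has to supply a concrete ResultRequest; the shared IndexJobActionHandler base drives the rest of the start-job flow. Reduced to plain Java, the template-method shape looks like the following sketch; the names and the window arithmetic are illustrative, not the plugin's.

// The base class owns the shared flow; subclasses only supply the concrete request.
abstract class JobHandlerSketch<R> {
    final R startJob(String configId, long nowMillis, long intervalMillis) {
        long start = nowMillis - intervalMillis; // one interval ending now
        return createResultRequest(configId, start, nowMillis);
    }

    protected abstract R createResultRequest(String configId, long start, long end);
}

public class ForecastJobHandlerSketch extends JobHandlerSketch<String> {
    @Override
    protected String createResultRequest(String configId, long start, long end) {
        // Stands in for: new ForecastResultRequest(configID, start, end)
        return "ForecastResultRequest(" + configId + ", " + start + ", " + end + ")";
    }

    public static void main(String[] args) {
        System.out.println(new ForecastJobHandlerSketch().startJob("my-forecaster", 1_700_000_000_000L, 60_000L));
    }
}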
+ * + * @param clusterService ClusterService + * @param client OS node client that executes actions on the local node + * @param transportService OS transport service + * @param forecastIndices forecast index manager + * @param forecasterId forecaster identifier + * @param seqNo sequence number of last modification + * @param primaryTerm primary term of last modification + * @param refreshPolicy refresh policy + * @param forecaster forecaster instance + * @param requestTimeout request time out configuration + * @param maxSingleStreamForecasters max single-stream forecasters allowed + * @param maxHCForecasters max HC forecasters allowed + * @param maxForecastFeatures max features allowed per forecaster + * @param maxCategoricalFields max number of categorical fields + * @param method Rest Method type + * @param xContentRegistry Registry which is used for XContentParser + * @param user User context + */ + public IndexForecasterActionHandler( + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + TransportService transportService, + ForecastIndexManagement forecastIndices, + String forecasterId, + Long seqNo, + Long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxForecastFeatures, + Integer maxCategoricalFields, + RestRequest.Method method, + NamedXContentRegistry xContentRegistry, + User user, + SearchFeatureDao searchFeatureDao + ) { + super( + clusterService, + client, + clientUtil, + transportService, + forecastIndices, + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry, + user, + searchFeatureDao, + null, + false + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java index 1db9bf340..ff52a9d2b 100644 --- a/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java +++ b/src/main/java/org/opensearch/forecast/settings/ForecastEnabledSetting.java @@ -29,8 +29,6 @@ public class ForecastEnabledSetting extends DynamicNumericSetting { public static final String FORECAST_BREAKER_ENABLED = "plugins.forecast.breaker.enabled"; - public static final String FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED = "plugins.forecast.door_keeper_in_cache.enabled";; - public static final Map> settings = unmodifiableMap(new HashMap>() { { /** @@ -42,16 +40,6 @@ public class ForecastEnabledSetting extends DynamicNumericSetting { * forecast breaker enable/disable setting */ put(FORECAST_BREAKER_ENABLED, Setting.boolSetting(FORECAST_BREAKER_ENABLED, true, NodeScope, Dynamic)); - - /** - * We have a bloom filter placed in front of inactive entity cache to - * filter out unpopular items that are not likely to appear more - * than once. Whether this bloom filter is enabled or not. 
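The setting removed in this hunk controlled a "door keeper": a Bloom filter placed in front of the inactive-entity cache that keeps out entities unlikely to appear more than once. For readers unfamiliar with the idea, here is a minimal sketch using Guava's BloomFilter (Guava is assumed on the classpath; the class is the editor's illustration, not the removed implementation):

import java.nio.charset.StandardCharsets;
import com.google.common.hash.BloomFilter;
import com.google.common.hash.Funnels;

public final class DoorKeeperSketch {
    // Admit an entity to the real cache only on its second appearance:
    // one-off entities stay out, at the cost of a small false-positive rate.
    private final BloomFilter<CharSequence> seenOnce =
        BloomFilter.create(Funnels.stringFunnel(StandardCharsets.UTF_8), 1_000_000, 0.01);

    public boolean admit(String entityModelId) {
        if (seenOnce.mightContain(entityModelId)) {
            return true; // likely seen before -> let it into the cache
        }
        seenOnce.put(entityModelId); // first sighting -> remember, but do not admit yet
        return false;
    }
}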
- */ - put( - FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, - Setting.boolSetting(FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, false, NodeScope, Dynamic) - ); } }); @@ -81,12 +69,4 @@ public static boolean isForecastEnabled() { public static boolean isForecastBreakerEnabled() { return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED); } - - /** - * If enabled, we filter out unpopular items that are not likely to appear more than once - * @return wWhether door keeper in cache is enabled or not. - */ - public static boolean isDoorKeeperInCacheEnabled() { - return ForecastEnabledSetting.getInstance().getSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED); - } } diff --git a/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java new file mode 100644 index 000000000..0ea00e129 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/task/ForecastTaskManager.java @@ -0,0 +1,689 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.task; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FORECASTER_IS_RUNNING; +import static org.opensearch.forecast.indices.ForecastIndexManagement.ALL_FORECAST_RESULTS_INDEX_PATTERN; +import static org.opensearch.forecast.model.ForecastTask.EXECUTION_START_TIME_FIELD; +import static org.opensearch.forecast.model.ForecastTask.FORECASTER_ID_FIELD; +import static org.opensearch.forecast.model.ForecastTask.IS_LATEST_FIELD; +import static org.opensearch.forecast.model.ForecastTask.PARENT_TASK_ID_FIELD; +import static org.opensearch.forecast.model.ForecastTask.TASK_TYPE_FIELD; +import static org.opensearch.forecast.model.ForecastTaskType.REALTIME_TASK_TYPES; +import static org.opensearch.forecast.settings.ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_BATCH_TASK_THREAD_POOL_NAME; +import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_ID_FIELD; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import 
org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +public class ForecastTaskManager extends + TaskManager { + private final Logger logger = LogManager.getLogger(ForecastTaskManager.class); + + public ForecastTaskManager( + TaskCacheManager forecastTaskCacheManager, + Client client, + NamedXContentRegistry xContentRegistry, + ForecastIndexManagement forecastIndices, + ClusterService clusterService, + Settings settings, + ThreadPool threadPool, + NodeStateManager nodeStateManager + ) { + super( + forecastTaskCacheManager, + clusterService, + client, + ForecastIndex.STATE.getIndexName(), + ForecastTaskType.REALTIME_TASK_TYPES, + forecastIndices, + nodeStateManager, + AnalysisType.FORECAST, + xContentRegistry, + FORECASTER_ID_FIELD, + MAX_OLD_TASK_DOCS_PER_FORECASTER, + settings, + threadPool, + ALL_FORECAST_RESULTS_INDEX_PATTERN, + FORECAST_BATCH_TASK_THREAD_POOL_NAME + ); + } +
+ /** + * Init realtime task cache. Realtime forecasting relies on the job scheduler to choose a node (the job coordinating node) + * to run the forecast job. Nodes that hold a primary or replica shard of the job index are candidates to run the forecast job. + * The job scheduler builds a hash ring on these candidate nodes and chooses one of them to run the forecast job. + * If a forecast job index shard relocates, for example when a new node joins the cluster, the job scheduler + * rebuilds the hash ring and may choose a different node to run the forecast job. So we need to init the realtime task cache + * on the new forecast job coordinating node. + * + * If the realtime task cache is initialized for the first time on this node, the listener returns true; otherwise + * it returns false. + * + * We don't clean up the realtime task cache on the old coordinating node as HourlyCron will clear the cache there. + * + * @param forecasterId forecaster id + * @param forecaster forecaster + * @param transportService transport service + * @param listener listener + */ + @Override + public void initCacheWithCleanupIfRequired( + String forecasterId, + Config forecaster, + TransportService transportService, + ActionListener<Boolean> listener + ) { + try { + if (taskCacheManager.getRealtimeTaskCache(forecasterId) != null) { + listener.onResponse(false); + return; + } + + getAndExecuteOnLatestForecasterLevelTask(forecasterId, REALTIME_TASK_TYPES, (forecastTaskOptional) -> { + if (!forecastTaskOptional.isPresent()) { + logger.debug("Can't find realtime task for forecaster {}, init realtime task cache directly", forecasterId); + ExecutorFunction function = () -> createNewTask( + forecaster, + null, + forecaster.getUser(), + clusterService.localNode().getId(), + ActionListener.wrap(r -> { + logger.info("Recreate realtime task successfully for forecaster {}", forecasterId); + taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds()); + listener.onResponse(true); + }, e -> { + logger.error("Failed to recreate realtime task for forecaster " + forecasterId, e); + listener.onFailure(e); + }) + ); + recreateRealtimeTaskBeforeExecuting(function, listener); + return; + } + + logger.info("Init realtime task cache for forecaster {}", forecasterId); + taskCacheManager.initRealtimeTaskCache(forecasterId, forecaster.getIntervalInMilliseconds()); + listener.onResponse(true); + }, transportService, false, listener); + } catch (Exception e) { + logger.error("Failed to init realtime task cache for " + forecasterId, e); + listener.onFailure(e); + } + }
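For orientation, here is a sketch of how a job coordinating node might consume the method above when a scheduled realtime run fires; the call site and variable names are illustrative assumptions, not part of this change:

    // Hypothetical call site on the job coordinating node (sketch only).
    taskManager.initCacheWithCleanupIfRequired(forecasterId, forecaster, transportService, ActionListener.wrap(firstInit -> {
        if (firstInit) {
            // First run on this coordinating node: the realtime task cache (and a recreated
            // realtime task doc, if the latest one was missing) is now in place.
            logger.info("Realtime task cache initialized for forecaster {}", forecasterId);
        }
        // In both cases it is safe to proceed with the scheduled forecast run.
    }, e -> logger.error("Failed to init realtime task cache for forecaster " + forecasterId, e)));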
+ + @Override + protected <T> void deleteTaskDocs( + String forecasterId, + SearchRequest searchRequest, + ExecutorFunction function, + ActionListener<T> listener + ) { + ActionListener<SearchResponse> searchListener = ActionListener.wrap(r -> { + Iterator<SearchHit> iterator = r.getHits().iterator(); + if (iterator.hasNext()) { + BulkRequest bulkRequest = new BulkRequest(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ForecastTask forecastTask = ForecastTask.parse(parser, searchHit.getId()); + logger.debug("Delete old task: {} of forecaster: {}", forecastTask.getTaskId(), forecastTask.getConfigId()); + bulkRequest.add(new DeleteRequest(ForecastIndex.STATE.getIndexName()).id(forecastTask.getTaskId())); + } catch (Exception e) { + listener.onFailure(e); + } + } + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + logger.info("Old forecast tasks deleted for forecaster {}", forecasterId); + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Add deleted task into cache. Task id: {}", bulkItemResponse.getId()); + // add deleted task in cache and delete its child tasks and forecast results + taskCacheManager.addDeletedTask(bulkItemResponse.getId()); + } + } + } + // delete child tasks and forecast results of this task + cleanChildTasksAndResultsOfDeletedTask(); + + function.execute(); + }, e -> { + logger.warn("Failed to clean forecast tasks for forecaster " + forecasterId, e); + listener.onFailure(e); + })); + } else { + function.execute(); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.execute(); + } else { + listener.onFailure(e); + } + }); + + client.search(searchRequest, searchListener); + } + + /** + * Update forecast task with specific fields. + * + * @param taskId forecast task id + * @param updatedFields updated fields, key: field name, value: new value + */ + public void updateForecastTask(String taskId, Map<String, Object> updatedFields) { + updateForecastTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated forecast task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update forecast task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update forecast task for specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: field name, value: new value + * @param listener action listener + */ + public void updateForecastTask(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) { + UpdateRequest updateRequest = new UpdateRequest(ForecastIndex.STATE.getIndexName(), taskId); + Map<String, Object> updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + }
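A short usage sketch of the two update overloads above; the task-state field names are assumptions borrowed from TimeSeriesTask-style task documents, and the call site is hypothetical:

    // Fire-and-forget: mark a task stopped (illustrative field names).
    Map<String, Object> updatedFields = new HashMap<>();
    updatedFields.put(TimeSeriesTask.STATE_FIELD, TaskState.STOPPED.name());
    taskManager.updateForecastTask(taskId, updatedFields);

    // With an explicit listener when the caller needs the outcome.
    taskManager.updateForecastTask(taskId, updatedFields, ActionListener.wrap(
        response -> logger.debug("Task {} updated with status {}", taskId, response.status()),
        e -> logger.error("Failed to update task " + taskId, e)
    ));

Note that both paths stamp LAST_UPDATE_TIME_FIELD automatically, so callers never need to set it themselves.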
+ + /** + * Get latest forecast task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param forecasterId forecaster id + * @param forecastTaskTypes forecast task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param <T> action listener response type + */ + public <T> void getAndExecuteOnLatestForecasterLevelTask( + String forecasterId, + List<ForecastTaskType> forecastTaskTypes, + Consumer<Optional<ForecastTask>> function, + TransportService transportService, + boolean resetTaskState, + ActionListener<T> listener + ) { + getAndExecuteOnLatestForecastTask( + forecasterId, + null, + null, + forecastTaskTypes, + function, + transportService, + resetTaskState, + listener + ); + } + + /** + * Get one latest forecast task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param forecasterId forecaster id + * @param parentTaskId parent task id + * @param entity entity value + * @param forecastTaskTypes forecast task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param <T> action listener response type + */ + public <T> void getAndExecuteOnLatestForecastTask( + String forecasterId, + String parentTaskId, + Entity entity, + List<ForecastTaskType> forecastTaskTypes, + Consumer<Optional<ForecastTask>> function, + TransportService transportService, + boolean resetTaskState, + ActionListener<T> listener + ) { + getAndExecuteOnLatestForecastTasks(forecasterId, parentTaskId, entity, forecastTaskTypes, (taskList) -> { + if (taskList != null && taskList.size() > 0) { + function.accept(Optional.ofNullable(taskList.get(0))); + } else { + function.accept(Optional.empty()); + } + }, transportService, resetTaskState, 1, listener); + }
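The [Important!] warnings above are easy to violate, so here is a sketch of a well-behaved consumer: every branch inside the function must eventually complete the listener, either directly or through its callees. The downstream method name is illustrative:

    // Sketch only: each path below terminates the listener exactly once.
    taskManager.getAndExecuteOnLatestForecasterLevelTask(forecasterId, REALTIME_TASK_TYPES, taskOptional -> {
        if (!taskOptional.isPresent()) {
            // Path 1: no latest realtime task; the listener must still be completed.
            listener.onFailure(new OpenSearchStatusException("no latest realtime task for " + forecasterId, RestStatus.NOT_FOUND));
            return;
        }
        // Path 2: hand off; the downstream method now owns the listener.
        profileTask(taskOptional.get(), listener);
    }, transportService, false, listener);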
+ + /** + * Get latest forecast tasks and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param forecasterId forecaster id + * @param parentTaskId parent task id + * @param entity entity value + * @param forecastTaskTypes forecast task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param size return how many forecast tasks + * @param listener action listener + * @param <T> response type of action listener + */ + public <T> void getAndExecuteOnLatestForecastTasks( + String forecasterId, + String parentTaskId, + Entity entity, + List<ForecastTaskType> forecastTaskTypes, + Consumer<List<ForecastTask>> function, + TransportService transportService, + boolean resetTaskState, + int size, + ActionListener<T> listener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(FORECASTER_ID_FIELD, forecasterId)); + query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + } + if (forecastTaskTypes != null && forecastTaskTypes.size() > 0) { + query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, TaskType.taskTypeToString(forecastTaskTypes))); + } + if (entity != null && !ParseUtils.isNullOrEmpty(entity.getAttributes())) { + String path = "entity"; + String entityKeyFieldName = path + ".name"; + String entityValueFieldName = path + ".value"; + + for (Map.Entry<String, String> attribute : entity.getAttributes().entrySet()) { + BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); + TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); + + entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); + query.filter(nestedQueryBuilder); + } + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(ForecastIndex.STATE.getIndexName()); + + client.search(searchRequest, ActionListener.wrap(r -> { + // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 + // getTotalHits will be null when track_total_hits is false in the query request. + // Add more checking here to cover some unknown cases. + List<ForecastTask> forecastTasks = new ArrayList<>(); + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + // don't throw exception here as consumer functions need to handle missing task + // in different way. + function.accept(forecastTasks); + return; + } + + Iterator<SearchHit> iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ForecastTask forecastTask = ForecastTask.parse(parser, searchHit.getId()); + forecastTasks.add(forecastTask); + } catch (Exception e) { + String message = "Failed to parse forecast task for forecaster " + forecasterId + ", task id " + searchHit.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + if (resetTaskState) { + resetLatestForecasterTaskState(forecastTasks, function, transportService, listener); + } else { + function.accept(forecastTasks); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.accept(new ArrayList<>()); + } else { + logger.error("Failed to search forecast task for forecaster " + forecasterId, e); + listener.onFailure(e); + } + })); + } + + /** + * Reset latest forecaster task state. Will reset both historical and realtime tasks. + * [Important!] Make sure listener returns in function + * + * @param forecastTasks forecast tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param <T> response type of action listener + */ + private <T> void resetLatestForecasterTaskState( + List<ForecastTask> forecastTasks, + Consumer<List<ForecastTask>> function, + TransportService transportService, + ActionListener<T> listener + ) { + List<ForecastTask> runningHistoricalTasks = new ArrayList<>(); + List<ForecastTask> runningRealtimeTasks = new ArrayList<>(); + for (ForecastTask forecastTask : forecastTasks) { + if (!forecastTask.isEntityTask() && !forecastTask.isDone()) { + if (!forecastTask.isHistoricalTask()) { + // try to reset task state if realtime task is not ended + runningRealtimeTasks.add(forecastTask); + } else { + // try to reset task state if historical task not updated for 2 piece intervals + runningHistoricalTasks.add(forecastTask); + } + } + } + + // TODO: reset real time and historical tasks + }
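The realtime/historical reset above is still a TODO. One plausible shape for the staleness test the inline comments describe (reset a historical task that has not been updated for two intervals) is sketched below; the accessors follow TimeSeriesTask naming, but this helper is an assumption, not part of this change:

    // Sketch only: does a running task look stale enough to reset?
    private boolean isTaskStale(ForecastTask task, long intervalMillis) {
        Instant lastUpdate = task.getLastUpdateTime();
        if (lastUpdate == null) {
            return false; // no signal; leave the task alone
        }
        // Stale when no update has landed for two full intervals.
        return Instant.now().toEpochMilli() - lastUpdate.toEpochMilli() > 2 * intervalMillis;
    }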
+ + private void recreateRealtimeTaskBeforeExecuting(ExecutorFunction function, ActionListener<Boolean> listener) { + if (indexManagement.doesStateIndexExist()) { + function.execute(); + } else { + // If forecast state index doesn't exist, create index and execute function. + indexManagement.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", ForecastIndex.STATE.getIndexName()); + function.execute(); + } else { + String error = String + .format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, ForecastIndex.STATE.getIndexName()); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + function.execute(); + } else { + logger.error("Failed to init forecast state index", e); + listener.onFailure(e); + } + })); + } + } + + /** + * Poll a deleted forecast task from cache and delete its child tasks and forecast results. + */ + public void cleanChildTasksAndResultsOfDeletedTask() { + if (!taskCacheManager.hasDeletedTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = taskCacheManager.pollDeletedTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteForecastResultsRequest = new DeleteByQueryRequest(ALL_FORECAST_RESULTS_INDEX_PATTERN); + deleteForecastResultsRequest.setQuery(new TermsQueryBuilder(TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteForecastResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted forecast results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(ForecastIndex.STATE.getIndexName()); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete forecast results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), FORECAST_BATCH_TASK_THREAD_POOL_NAME); + } + + @Override + public void startHistorical( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener<JobResponse> listener + ) { + // TODO Auto-generated method stub + + } + + @Override + protected List<ForecastTaskType> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag) { + if (dateRange == null) { + return ForecastTaskType.REALTIME_TASK_TYPES; + } else if (resetLatestTaskStateFlag) { + return ForecastTaskType.ALL_HISTORICAL_TASK_TYPES; + } else { + return ForecastTaskType.HISTORICAL_FORECASTER_TASK_TYPES; + } + } + + @Override + protected TaskType getTaskType(Config config, DateRange dateRange) { + if (dateRange == null) { + return config.isHighCardinality() + ? ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER + : ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM; + } else { + return config.isHighCardinality() + ? ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER + : ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM; + } + }
+ + @Override + protected void createNewTask( + Config config, + DateRange dateRange, + User user, + String coordinatingNode, + ActionListener<JobResponse> listener + ) { + String userName = user == null ? null : user.getName(); + Instant now = Instant.now(); + String taskType = getTaskType(config, dateRange).name(); + ForecastTask forecastTask = new ForecastTask.Builder() + .configId(config.getId()) + .forecaster((Forecaster) config) + .isLatest(true) + .taskType(taskType) + .executionStartTime(now) + .taskProgress(0.0f) + .initProgress(0.0f) + .state(TaskState.CREATED.name()) + .lastUpdateTime(now) + .startedBy(userName) + .coordinatingNode(coordinatingNode) + .dateRange(dateRange) + .user(user) + .build(); + + createTaskDirectly( + forecastTask, + r -> onIndexConfigTaskResponse( + r, + forecastTask, + (response, delegatedListener) -> cleanOldConfigTaskDocs(response, forecastTask, delegatedListener), + listener + ), + listener + ); + + } + + @Override + protected List<ForecastTaskType> getRealTimeTaskTypes() { + return ForecastTaskType.REALTIME_TASK_TYPES; + } + + @Override + public void cleanConfigCache( + TimeSeriesTask task, + TransportService transportService, + ExecutorFunction function, + ActionListener listener + ) { + // no op for forecaster as we rely on state ttl to auto clean it + listener.onResponse(null); + } + + @Override + protected boolean isHistoricalHCTask(TimeSeriesTask task) { + return ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER.name().equals(task.getTaskType()); + } + + @Override + public void stopHistoricalAnalysis(String forecasterId, Optional<ForecastTask> forecastTask, User user, ActionListener<JobResponse> listener) { + // TODO Auto-generated method stub + + } + + @Override + protected void resetHistoricalConfigTaskState( + List runningHistoricalTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + // TODO complete reset historical part; now only execute function + function.execute(); + } + + @Override + protected void onIndexConfigTaskResponse( + IndexResponse response, + ForecastTask forecastTask, + BiConsumer<IndexResponse, ActionListener<JobResponse>> function, + ActionListener<JobResponse> listener + ) { + if (response == null || response.getResult() != CREATED) { + String errorMsg = ExceptionUtil.getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + forecastTask.setTaskId(response.getId()); + ActionListener<JobResponse> delegatedListener = ActionListener.wrap(r -> { listener.onResponse(r); }, e -> { + handleTaskException(forecastTask, e); + if (e instanceof DuplicateTaskException) { + listener.onFailure(new OpenSearchStatusException(FORECASTER_IS_RUNNING, RestStatus.BAD_REQUEST)); + } else { + // TODO: For a historical forecast task, what should we do if any other exception happens? + // For realtime forecast, the task cache is initialized when the realtime job starts; see + // ForecastTaskManager#initRealtimeTaskCache for details. The realtime task cache is not + // initialized yet when we create the forecast task, so there is no need to clean it up here. + listener.onFailure(e); + } + }); + // TODO: what to do if this is a historical task?
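One possible answer to the TODO above, sketched before the delegation below: route historical tasks to the batch runner instead of the realtime delegate chain. isHistoricalTask() and the runBatchResultAction stub both already exist in this file; the wiring itself is an assumption, not part of this change:

    // Sketch only (not part of this change): branch on the task flavor first.
    if (forecastTask.isHistoricalTask()) {
        runBatchResultAction(response, forecastTask, delegatedListener);
        return;
    }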
+ if (function != null) { + function.accept(response, delegatedListener); + } + } + + @Override + protected void runBatchResultAction(IndexResponse response, ForecastTask tsTask, ActionListener listener) { + // TODO Auto-generated method stub + + } + + @Override + protected BiCheckedFunction getTaskParser() { + return ForecastTask::parse; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java new file mode 100644 index 000000000..eab816842 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.DeleteModelResponse; + +public class DeleteForecastModelAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "model/delete"; + public static final DeleteForecastModelAction INSTANCE = new DeleteForecastModelAction(); + + private DeleteForecastModelAction() { + super(NAME, DeleteModelResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java new file mode 100644 index 000000000..fad3bdd12 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/DeleteForecastModelTransportAction.java @@ -0,0 +1,58 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseDeleteModelTransportAction; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class DeleteForecastModelTransportAction extends + BaseDeleteModelTransportAction { + + @Inject + public DeleteForecastModelTransportAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + NodeStateManager nodeStateManager, + ForecastCacheProvider cache, + TaskCacheManager taskCacheManager, + ForecastColdStart coldStarter + ) { + super( + threadPool, + clusterService, + transportService, + actionFilters, + nodeStateManager, + cache, + taskCacheManager, + coldStarter, + DeleteForecastModelAction.NAME + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java new file mode 100644 index 000000000..77eec3d51 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class EntityForecastResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "entity/result"; + public static final EntityForecastResultAction INSTANCE = new EntityForecastResultAction(); + + private EntityForecastResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java new file mode 100644 index 000000000..6e6a79e0a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/EntityForecastResultTransportAction.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ +package org.opensearch.forecast.transport; + +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.EntityResultProcessor; +import org.opensearch.timeseries.transport.EntityResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +/** + * Entry-point for HC forecast workflow. We have created multiple queues for coordinating + * the workflow. The overall workflow is: + * 1. We store as many frequently used entity models in a cache as allowed by the + * memory limit (by default 10% heap). If an entity feature is a hit, we use the in-memory model + * to forecast and record results using the result write queue. + * 2. If an entity feature is a miss, we check if there is free memory or whether any other + * entity's model can be evicted. An in-memory entity's frequency may be lower + * compared to the cache miss entity. If that's the case, we replace the lower + * frequency entity's model with the higher frequency entity's model. To load the + * higher frequency entity's model, we first check if a model exists on disk by + * sending a checkpoint read queue request. If there is a checkpoint, we load it + * to memory, perform forecast, and save the result using the result write queue. + * Otherwise, we enqueue a cold start request to the cold start queue for model + * training. If training is successful, we save the learned model via the checkpoint + * write queue. + * 3.
We also have the cold entity queue configured for cold entities, and the model + * training and inference are connected by serial juxtaposition to limit resource usage. + */ +public class EntityForecastResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(EntityForecastResultTransportAction.class); + private CircuitBreakerService circuitBreakerService; + private CacheProvider cache; + private final NodeStateManager stateManager; + private ThreadPool threadPool; + private EntityResultProcessor intervalDataProcessor; + + @Inject + public EntityForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + ForecastModelManager manager, + CircuitBreakerService adCircuitBreakerService, + ForecastCacheProvider entityCache, + NodeStateManager stateManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastColdEntityWorker coldEntityQueue, + ThreadPool threadPool, + ForecastColdStartWorker entityColdStartWorker, + Stats timeSeriesStats + ) { + super(EntityForecastResultAction.NAME, transportService, actionFilters, EntityResultRequest::new); + this.circuitBreakerService = adCircuitBreakerService; + this.cache = entityCache; + this.stateManager = stateManager; + this.threadPool = threadPool; + this.intervalDataProcessor = new EntityResultProcessor<>( + entityCache, + manager, + ForecastIndex.RESULT, + indexUtil, + resultWriteQueue, + ForecastResultWriteRequest.class, + timeSeriesStats, + entityColdStartWorker, + checkpointReadQueue, + coldEntityQueue + ); + } + + @Override + protected void doExecute(Task task, EntityResultRequest request, ActionListener listener) { + if (circuitBreakerService.isOpen()) { + threadPool + .executor(TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME) + .execute(() -> cache.get().releaseMemoryForOpenCircuitBreaker()); + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String forecasterId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(forecasterId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error("Previous exception of {}: {}", forecasterId, exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, forecasterId); + } + + stateManager + .getConfig( + forecasterId, + AnalysisType.FORECAST, + intervalDataProcessor.onGetConfig(listener, forecasterId, request, previousException) + ); + } catch (Exception exception) { + LOG.error("fail to get entity's anomaly grade", exception); + listener.onFailure(exception); + } + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java new file mode 100644 index 000000000..449e9a791 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastResultAction extends ActionType<ForecastResultResponse> { + // External Action which is used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/run"; + public static final ForecastResultAction INSTANCE = new ForecastResultAction(); + + private ForecastResultAction() { + super(NAME, ForecastResultResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java new file mode 100644 index 000000000..6394636b3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkAction.java @@ -0,0 +1,34 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.transport.TransportRequestOptions; + +public class ForecastResultBulkAction extends ActionType<ResultBulkResponse> { + + // Internal Action which is not used for public facing RestAPIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "write/bulk"; + public static final ForecastResultBulkAction INSTANCE = new ForecastResultBulkAction(); + + private ForecastResultBulkAction() { + super(NAME, ResultBulkResponse::new); + } + + @Override + public TransportRequestOptions transportOptions(Settings settings) { + return TransportRequestOptions.builder().withType(TransportRequestOptions.Type.BULK).build(); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java new file mode 100644 index 000000000..730275b4d --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkRequest.java @@ -0,0 +1,30 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.timeseries.transport.ResultBulkRequest; + +public class ForecastResultBulkRequest extends ResultBulkRequest { + + public ForecastResultBulkRequest() { + super(); + } + + public ForecastResultBulkRequest(StreamInput in) throws IOException { + super(in, ForecastResultWriteRequest::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java new file mode 100644 index 000000000..95422a98a --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultBulkTransportAction.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT; + +import java.util.List; + +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.index.IndexingPressure; +import org.opensearch.timeseries.transport.ResultBulkTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecastResultBulkTransportAction extends + ResultBulkTransportAction { + + @Inject + public ForecastResultBulkTransportAction( + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + ClusterService clusterService, + Client client + ) { + super( + ForecastResultBulkAction.NAME, + transportService, + actionFilters, + indexingPressure, + settings, + client, + FORECAST_INDEX_PRESSURE_SOFT_LIMIT.get(settings), + FORECAST_INDEX_PRESSURE_HARD_LIMIT.get(settings), + ForecastIndex.RESULT.getIndexName(), + ForecastResultBulkRequest::new + ); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_SOFT_LIMIT, it -> softLimit = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_INDEX_PRESSURE_HARD_LIMIT, it -> hardLimit = it); + } + + @Override + protected BulkRequest prepareBulkRequest(float indexingPressurePercent, ForecastResultBulkRequest request) { + BulkRequest bulkRequest = new BulkRequest(); + List results = request.getAnomalyResults(); + + if (indexingPressurePercent <= softLimit) { + for (ForecastResultWriteRequest resultWriteRequest : results) { + addResult(bulkRequest, resultWriteRequest.getResult(), resultWriteRequest.getResultIndex()); + } + } else if (indexingPressurePercent <= hardLimit) { + // exceed soft limit (60%) but 
smaller than hard limit (90%) + float acceptProbability = 1 - indexingPressurePercent; + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (random.nextFloat() < acceptProbability) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } else { + // if exceeding hard limit, only index error result + for (ForecastResultWriteRequest resultWriteRequest : results) { + ForecastResult result = resultWriteRequest.getResult(); + if (result.isHighPriority()) { + addResult(bulkRequest, result, resultWriteRequest.getResultIndex()); + } + } + } + + return bulkRequest; + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java new file mode 100644 index 000000000..166936171 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultRequest.java @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.transport.ResultRequest; + +public class ForecastResultRequest extends ResultRequest { + public ForecastResultRequest(StreamInput in) throws IOException { + super(in); + } + + public ForecastResultRequest(String forecastID, long start, long end) { + super(forecastID, start, end); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(ForecastCommonMessages.FORECASTER_ID_MISSING_MSG, validationException); + } + if (start <= 0 || end <= 0 || start > end) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", ForecastCommonMessages.INVALID_TIMESTAMP_ERR_MSG, start, end), + validationException + ); + } + return validationException; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ForecastCommonName.ID_JSON_KEY, configId); + builder.field(CommonName.START_JSON_KEY, start); + builder.field(CommonName.END_JSON_KEY, end); + builder.endObject(); + return builder; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java new file mode 100644 index 000000000..568a331a3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultResponse.java @@ -0,0 +1,198 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this 
file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.transport.ResultResponse; + +public class ForecastResultResponse extends ResultResponse { + public static final String DATA_QUALITY_JSON_KEY = "dataQuality"; + public static final String ERROR_JSON_KEY = "error"; + public static final String FEATURES_JSON_KEY = "features"; + public static final String FEATURE_VALUE_JSON_KEY = "value"; + public static final String RCF_TOTAL_UPDATES_JSON_KEY = "rcfTotalUpdates"; + public static final String FORECASTER_INTERVAL_IN_MINUTES_JSON_KEY = "forecasterIntervalInMinutes"; + public static final String FORECAST_VALUES_JSON_KEY = "forecastValues"; + public static final String FORECAST_UPPERS_JSON_KEY = "forecastUppers"; + public static final String FORECAST_LOWERS_JSON_KEY = "forecastLowers"; + + private Double dataQuality; + private float[] forecastsValues; + private float[] forecastsUppers; + private float[] forecastsLowers; + + // used when returning an error/exception or empty result + public ForecastResultResponse( + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster + ) { + this(Double.NaN, features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster, null, null, null); + } + + public ForecastResultResponse( + Double confidence, + List features, + String error, + Long rcfTotalUpdates, + Long forecasterIntervalInMinutes, + Boolean isHCForecaster, + float[] forecastsValues, + float[] forecastsUppers, + float[] forecastsLowers + ) { + super(features, error, rcfTotalUpdates, forecasterIntervalInMinutes, isHCForecaster); + this.dataQuality = confidence; + this.forecastsValues = forecastsValues; + this.forecastsUppers = forecastsUppers; + this.forecastsLowers = forecastsLowers; + } + + public ForecastResultResponse(StreamInput in) throws IOException { + super(in); + dataQuality = in.readDouble(); + int size = in.readVInt(); + features = new ArrayList(); + for (int i = 0; i < size; i++) { + features.add(new FeatureData(in)); + } + error = in.readOptionalString(); + rcfTotalUpdates = in.readOptionalLong(); + configIntervalInMinutes = in.readOptionalLong(); + isHC = in.readOptionalBoolean(); + + if (in.readBoolean()) { + forecastsValues = in.readFloatArray(); + forecastsUppers = in.readFloatArray(); + forecastsLowers = in.readFloatArray(); + } else { + forecastsValues = null; + forecastsUppers = null; + forecastsLowers = null; + } + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(dataQuality); + out.writeVInt(features.size()); + for (FeatureData feature : features) { + feature.writeTo(out); + } + out.writeOptionalString(error); + out.writeOptionalLong(rcfTotalUpdates); + out.writeOptionalLong(configIntervalInMinutes); + out.writeOptionalBoolean(isHC); + + if 
(forecastsValues != null) { + if (forecastsUppers == null || forecastsLowers == null) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "null value: forecastsUppers: %s, forecastsLowers: %s", forecastsUppers, forecastsLowers) + ); + } + out.writeBoolean(true); + out.writeFloatArray(forecastsValues); + out.writeFloatArray(forecastsUppers); + out.writeFloatArray(forecastsLowers); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DATA_QUALITY_JSON_KEY, dataQuality); + builder.field(ERROR_JSON_KEY, error); + builder.startArray(FEATURES_JSON_KEY); + for (FeatureData feature : features) { + feature.toXContent(builder, params); + } + builder.endArray(); + builder.field(RCF_TOTAL_UPDATES_JSON_KEY, rcfTotalUpdates); + builder.field(FORECASTER_INTERVAL_IN_MINUTES_JSON_KEY, configIntervalInMinutes); + builder.field(FORECAST_VALUES_JSON_KEY, forecastsValues); + builder.field(FORECAST_UPPERS_JSON_KEY, forecastsUppers); + builder.field(FORECAST_LOWERS_JSON_KEY, forecastsLowers); + builder.endObject(); + return builder; + } + + /** + * + * Convert ForecastResultResponse to ForecastResult + * + * @param forecastId Forecaster Id + * @param dataStartInstant data start time + * @param dataEndInstant data end time + * @param executionStartInstant execution start time + * @param executionEndInstant execution end time + * @param schemaVersion Schema version + * @param user Detector author + * @param error Error + * @return converted ForecastResult + */ + @Override + public List toIndexableResults( + String forecastId, + Instant dataStartInstant, + Instant dataEndInstant, + Instant executionStartInstant, + Instant executionEndInstant, + Integer schemaVersion, + User user, + String error + ) { + // Forecast interval in milliseconds + long forecasterIntervalMilli = Duration.between(dataStartInstant, dataEndInstant).toMillis(); + return ForecastResult + .fromRawRCFCasterResult( + forecastId, + forecasterIntervalMilli, + dataQuality, + features, + dataStartInstant, + dataEndInstant, + executionStartInstant, + executionEndInstant, + error, + Optional.empty(), + user, + schemaVersion, + null, // single-stream real-time has no model id + forecastsValues, + forecastsUppers, + forecastsLowers, + null // real time results have no task id + ); + } + + @Override + public boolean shouldSave() { + return super.shouldSave() || (forecastsValues != null && forecastsValues.length > 0); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java new file mode 100644 index 000000000..19809c17e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultTransportAction.java @@ -0,0 +1,154 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.util.HashSet; +import java.util.Set; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; 
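An aside before the transport action's imports continue: ForecastResultResponse above guards its three forecast arrays with one presence flag in writeTo, and the stream constructor must mirror that order exactly. A minimal round-trip sketch of the pattern, using the BytesStreamOutput scratch buffer from OpenSearch core (variable names are illustrative):

    // Writer: one boolean guards all three arrays.
    BytesStreamOutput out = new BytesStreamOutput();
    out.writeBoolean(values != null);
    if (values != null) {
        out.writeFloatArray(values);
        out.writeFloatArray(uppers);
        out.writeFloatArray(lowers);
    }

    // Reader: the mirror image; reading in a different order corrupts the stream.
    StreamInput in = out.bytes().streamInput();
    float[] readValues = null, readUppers = null, readLowers = null;
    if (in.readBoolean()) {
        readValues = in.readFloatArray();
        readUppers = in.readFloatArray();
        readLowers = in.readFloatArray();
    }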
+import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.constant.ForecastCommonMessages; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultTransportAction extends HandledTransportAction<ForecastResultRequest, ForecastResultResponse> { + + private static final Logger LOG = LogManager.getLogger(ForecastResultTransportAction.class); + private ForecastReultProcessor resultProcessor; + private final Client client; + private CircuitBreakerService adCircuitBreakerService;
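A caution about the hcForecasters set declared just below: it is mutated from transport callbacks, so if those can run on different threads, a concurrent set is safer than the plain HashSet assigned in the constructor. A hedged alternative, not what this change does:

    // Thread-safe variant of the HC-forecaster id tracking.
    private final Set<String> hcForecasters = ConcurrentHashMap.newKeySet();

    // The call sites stay the same; the add happens downstream once the
    // config is recognized as high-cardinality:
    // hcForecasters.add(forecastID);    // when an HC run starts
    // hcForecasters.remove(forecastID); // when the run completes or fails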
+ private final Set hcForecasters; + private final Stats adStats; + private final NodeStateManager nodeStateManager; + + @Inject + public ForecastResultTransportAction( + ActionFilters actionFilters, + TransportService transportService, + Settings settings, + Client client, + SecurityClientUtil clientUtil, + NodeStateManager nodeStateManager, + FeatureManager featureManager, + ForecastModelManager modelManager, + HashRing hashRing, + ClusterService clusterService, + IndexNameExpressionResolver indexNameExpressionResolver, + CircuitBreakerService adCircuitBreakerService, + Stats forecastStats, + ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager realTimeTaskManager + ) { + super(ForecastResultAction.NAME, transportService, actionFilters, ForecastResultRequest::new); + this.resultProcessor = new ForecastReultProcessor( + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + TimeSeriesSettings.INTERVAL_RATIO_FOR_REQUESTS, + EntityForecastResultAction.NAME, + StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT, + settings, + clusterService, + threadPool, + hashRing, + nodeStateManager, + transportService, + forecastStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + ForecastResultResponse.class, + featureManager + ); + this.client = client; + this.adCircuitBreakerService = adCircuitBreakerService; + this.hcForecasters = new HashSet<>(); + this.adStats = forecastStats; + this.nodeStateManager = nodeStateManager; + } + + @Override + protected void doExecute(Task task, ForecastResultRequest request, ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + String forecastID = request.getConfigId(); + ActionListener original = listener; + listener = ActionListener.wrap(r -> { + hcForecasters.remove(forecastID); + original.onResponse(r); + }, e -> { + // If exception is AnomalyDetectionException and it should not be counted in stats, + // we will not count it in failure stats. 
+ if (!(e instanceof TimeSeriesException) || ((TimeSeriesException) e).isCountedInStats()) { + adStats.getStat(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName()).increment(); + if (hcForecasters.contains(forecastID)) { + adStats.getStat(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName()).increment(); + } + } + hcForecasters.remove(forecastID); + original.onFailure(e); + }); + + if (!ForecastEnabledSetting.isForecastEnabled()) { + throw new EndRunException(forecastID, ForecastCommonMessages.DISABLED_ERR_MSG, true).countedInStats(false); + } + + adStats.getStat(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName()).increment(); + + if (adCircuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(forecastID, CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + try { + nodeStateManager + .getConfig( + forecastID, + AnalysisType.FORECAST, + resultProcessor.onGetConfig(listener, forecastID, request, hcForecasters) + ); + } catch (Exception ex) { + ResultProcessor.handleExecuteException(ex, listener, forecastID); + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastReultProcessor.java b/src/main/java/org/opensearch/forecast/transport/ForecastReultProcessor.java new file mode 100644 index 000000000..c6db38623 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastReultProcessor.java @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_PAGE_SIZE; + +import java.util.ArrayList; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.model.Config; +import 
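For orientation, a hypothetical end-to-end invocation of this action for one forecast interval; the ForecastResultRequest constructor is assumed here to take (configId, start, end), matching how result requests are built elsewhere in this patch:

```java
// Hypothetical caller: run one realtime forecast interval through the transport layer.
ForecastResultRequest request = new ForecastResultRequest(forecasterId, windowStart, windowEnd);
client.execute(ForecastResultAction.INSTANCE, request, ActionListener.wrap(
    // on success the response carries feature data plus the forecast value/bound arrays
    response -> LOG.info("forecast interval for {} finished", forecasterId),
    exception -> LOG.error("forecast interval for " + forecasterId + " failed", exception)
));
```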
diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java new file mode 100644 index 000000000..c6db38623 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastResultProcessor.java @@ -0,0 +1,209 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_PAGE_SIZE; + +import java.util.ArrayList; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.ResultProcessor; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.SingleStreamResultRequest; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class ForecastResultProcessor extends + ResultProcessor { + + private static final Logger LOG = LogManager.getLogger(ForecastResultProcessor.class); + + public ForecastResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + Stats timeSeriesStats, + ForecastTaskManager realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager + ) { + super( + requestTimeoutSetting, + intervalRatioForRequests, + entityResultAction, + hcRequestCountStat, + settings, + clusterService, + threadPool, + TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME, + hashRing, + nodeStateManager, + transportService, + timeSeriesStats, + realTimeTaskManager, + xContentRegistry, + client, + clientUtil, + indexNameExpressionResolver, + transportResultResponseClazz, + featureManager, + FORECAST_MAX_ENTITIES_PER_INTERVAL, + FORECAST_PAGE_SIZE, + AnalysisType.FORECAST + ); + } + + @Override + protected ActionListener onFeatureResponseForSingleStreamConfig( + String forecasterId, + Config config, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime + ) { + return ActionListener.wrap(featureOptional -> { + Optional previousException = nodeStateManager.fetchExceptionAndClear(forecasterId); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous forecast exception of [{}]", forecasterId), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + Forecaster forecaster = (Forecaster) config; + + if (featureOptional.getUnprocessedFeatures().isEmpty()) { + // Missing features are common when the data has holes. Respond with an empty + // response and log only at debug level to avoid bloating the logs.
+ LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, forecasterId); + listener + .onResponse( + ResultResponse + .create( + new ArrayList(), + "No data in current window", + null, + null, + false, + transportResultResponseClazz + ) + ); + return; + } + + final AtomicReference failure = new AtomicReference(); + + LOG.info("Sending forecast single stream request to {} for model {}", rcfNode.getId(), rcfModelId); + + transportService + .sendRequest( + rcfNode, + ForecastSingleStreamResultAction.NAME, + new SingleStreamResultRequest( + forecasterId, + rcfModelId, + dataStartTime, + dataEndTime, + featureOptional.getUnprocessedFeatures().get() + ), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(rcfNode.getId(), forecasterId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else if (!featureOptional.getUnprocessedFeatures().isPresent()) { + // Feature not available is common when we have data holes. Respond empty response + // and don't log to avoid bloating our logs. + LOG.debug("No data in current window between {} and {} for {}", dataStartTime, dataEndTime, forecasterId); + listener + .onResponse( + ResultResponse + .create( + new ArrayList(), + "No data in current window", + null, + null, + false, + transportResultResponseClazz + ) + ); + } else { + listener + .onResponse( + ResultResponse + .create( + new ArrayList(), + null, + null, + forecaster.getIntervalInMinutes(), + true, + transportResultResponseClazz + ) + ); + } + }, exception -> { handleQueryFailure(exception, listener, forecasterId); }); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java new file mode 100644 index 000000000..6b8687b82 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastSingleStreamResultAction extends ActionType { + // Internal Action which is not used for public facing RestAPIs. 
diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java new file mode 100644 index 000000000..6b8687b82 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class ForecastSingleStreamResultAction extends ActionType { + // Internal action, not exposed through public-facing REST APIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "singlestream/result"; + public static final ForecastSingleStreamResultAction INSTANCE = new ForecastSingleStreamResultAction(); + + private ForecastSingleStreamResultAction() { + super(NAME, AcknowledgedResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java new file mode 100644 index 000000000..54988ff3c --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecastSingleStreamResultTransportAction.java @@ -0,0 +1,237 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import java.time.Instant; +import java.util.List; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.ml.RCFCasterResult; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.SingleStreamResultRequest; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.transport.TransportService; + +import com.amazon.randomcutforest.parkservices.RCFCaster; + +public class ForecastSingleStreamResultTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(ForecastSingleStreamResultTransportAction.class); + private CircuitBreakerService circuitBreakerService; + private ForecastCacheProvider cache; + private final NodeStateManager stateManager; + private ForecastCheckpointReadWorker checkpointReadQueue; + private ForecastModelManager modelManager; + private ForecastIndexManagement
indexUtil; + private ForecastResultWriteWorker resultWriteQueue; + private Stats stats; + private ForecastColdStartWorker forecastColdStartQueue; + + @Inject + public ForecastSingleStreamResultTransportAction( + TransportService transportService, + ActionFilters actionFilters, + CircuitBreakerService circuitBreakerService, + ForecastCacheProvider cache, + NodeStateManager stateManager, + ForecastCheckpointReadWorker checkpointReadQueue, + ForecastModelManager modelManager, + ForecastIndexManagement indexUtil, + ForecastResultWriteWorker resultWriteQueue, + Stats stats, + ForecastColdStartWorker forecastColdStartQueue + ) { + super(ForecastSingleStreamResultAction.NAME, transportService, actionFilters, SingleStreamResultRequest::new); + this.circuitBreakerService = circuitBreakerService; + this.cache = cache; + this.stateManager = stateManager; + this.checkpointReadQueue = checkpointReadQueue; + this.modelManager = modelManager; + this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.stats = stats; + this.forecastColdStartQueue = forecastColdStartQueue; + } + + @Override + protected void doExecute(Task task, SingleStreamResultRequest request, ActionListener listener) { + if (circuitBreakerService.isOpen()) { + listener.onFailure(new LimitExceededException(request.getConfigId(), CommonMessages.MEMORY_CIRCUIT_BROKEN_ERR_MSG, false)); + return; + } + + try { + String forecasterId = request.getConfigId(); + + Optional previousException = stateManager.fetchExceptionAndClear(forecasterId); + + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", forecasterId), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + + listener = ExceptionUtil.wrapListener(listener, exception, forecasterId); + } + + stateManager.getConfig(forecasterId, AnalysisType.FORECAST, onGetConfig(listener, forecasterId, request, previousException)); + } catch (Exception exception) { + LOG.error("fail to get forecast result", exception); + listener.onFailure(exception); + } + } + + public ActionListener> onGetConfig( + ActionListener listener, + String forecasterId, + SingleStreamResultRequest request, + Optional prevException + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + Instant executionStartTime = Instant.now(); + + String modelId = request.getModelId(); + double[] datapoint = request.getDataPoint(); + ModelState modelState = cache.get().get(modelId, config);
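+ // From here the flow forks: a cache miss defers scoring by queuing a FeatureRequest
+ // for the checkpoint read worker; a cache hit scores inline and queues any results
+ // for the result write worker, with a corrupt model being dropped and re-queued for
+ // cold start below.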
+ if (modelState == null) { + // cache miss + checkpointReadQueue + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + request.getModelId(), + datapoint, + request.getStart() + ) + ); + } else { + try { + RCFCasterResult result = modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + modelState, + modelId, + Optional.empty(), + config + ); + // result.getRcfScore() = 0 means the model is not initialized + if (result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + ParseUtils.getFeatureData(datapoint, config), + Optional.empty(), + indexUtil.getSchemaVersion(ForecastIndex.RESULT), + modelId, + null, + null + ); + + for (ForecastResult r : indexableResults) { + resultWriteQueue + .put( + ResultWriteRequest + .create( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + r, + config.getCustomResultIndex(), + ForecastResultWriteRequest.class + ) + ); + } + } + } catch (IllegalArgumentException e) { + // Failing to score is likely due to model corruption. Re-run cold start to recover. + LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + forecastColdStartQueue + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + modelId, + datapoint, + request.getStart() + ) + ); + } + } + + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "fail to get forecast result for forecaster [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java new file mode 100644 index 000000000..bfd915288 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.JobResponse; + +public class ForecasterJobAction extends ActionType { + // External action used by public-facing REST APIs.
+ public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/jobmanagement"; + public static final ForecasterJobAction INSTANCE = new ForecasterJobAction(); + + private ForecasterJobAction() { + super(NAME, JobResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java new file mode 100644 index 000000000..02665891c --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/ForecasterJobTransportAction.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_START_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_REQUEST_TIMEOUT; + +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseJobTransportAction; +import org.opensearch.transport.TransportService; + +public class ForecasterJobTransportAction extends + BaseJobTransportAction< + ForecastIndex, + ForecastIndexManagement, + TaskCacheManager, + ForecastTaskType, + ForecastTask, + ForecastTaskManager, + ForecastResult, + ExecuteForecastResultResponseRecorder, + ForecastIndexJobActionHandler + > { + + @Inject + public ForecasterJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastIndexJobActionHandler forecastIndexJobActionHandler + ) { + super( + transportService, + actionFilters, + client, + clusterService, + settings, + xContentRegistry, + FORECAST_FILTER_BY_BACKEND_ROLES, + ForecasterJobAction.NAME, + FORECAST_REQUEST_TIMEOUT, + FAIL_TO_START_FORECASTER, + FAIL_TO_STOP_FORECASTER, + Forecaster.class, + forecastIndexJobActionHandler + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java new file mode 100644 index 000000000..ef5d13540 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterAction.java @@ -0,0 +1,26 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * 
compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class GetForecasterAction extends ActionType { + // External action used by public-facing REST APIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecasters/get"; + public static final GetForecasterAction INSTANCE = new GetForecasterAction(); + + private GetForecasterAction() { + super(NAME, GetForecasterResponse::new); + } + +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java new file mode 100644 index 000000000..ab35a04e3 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterResponse.java @@ -0,0 +1,192 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class GetForecasterResponse extends ActionResponse implements ToXContentObject { + + public static final String FORECASTER_PROFILE = "forecasterProfile"; + public static final String ENTITY_PROFILE = "entityProfile"; + private String id; + private long version; + private long primaryTerm; + private long seqNo; + private Forecaster forecaster; + private Job forecastJob; + private ForecastTask realtimeTask; + private ForecastTask historicalTask; + private RestStatus restStatus; + // TODO: add forecaster and entity profile + // private DetectorProfile detectorProfile; + // private EntityProfile entityProfile; + private boolean profileResponse; + private boolean returnJob; + private boolean returnTask; + + public GetForecasterResponse(StreamInput in) throws IOException { + super(in); + profileResponse = in.readBoolean(); + if (profileResponse) { + // TODO: read profile fields once forecaster/entity profile is supported + } else { + id = in.readString(); + version = in.readLong(); + primaryTerm = in.readLong(); + seqNo = in.readLong(); + restStatus = in.readEnum(RestStatus.class); + forecaster = new Forecaster(in); + returnJob = in.readBoolean(); + if (returnJob) { + forecastJob = new Job(in); + } else { + forecastJob = null; + } + returnTask = in.readBoolean(); + if (in.readBoolean()) { + realtimeTask = new ForecastTask(in); + } else { + realtimeTask = null; + } + if (in.readBoolean()) { + historicalTask = new ForecastTask(in); + } else { + historicalTask = null; + } + } + + } + + public GetForecasterResponse( + String id, + long version, + long primaryTerm, + long seqNo, + Forecaster forecaster, + Job job, + boolean returnJob, + ForecastTask realtimeTask, + ForecastTask historicalTask,
boolean returnTask, + RestStatus restStatus, + boolean profileResponse + ) { + this.id = id; + this.version = version; + this.primaryTerm = primaryTerm; + this.seqNo = seqNo; + this.forecaster = forecaster; + this.returnJob = returnJob; + if (this.returnJob) { + this.forecastJob = job; + } else { + this.forecastJob = null; + } + this.returnTask = returnTask; + if (this.returnTask) { + this.realtimeTask = realtimeTask; + this.historicalTask = historicalTask; + } else { + this.realtimeTask = null; + this.historicalTask = null; + } + this.restStatus = restStatus; + this.profileResponse = profileResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + if (profileResponse) { + out.writeBoolean(true); // profileResponse is true + + } else { + out.writeBoolean(false); // profileResponse is false + out.writeString(id); + out.writeLong(version); + out.writeLong(primaryTerm); + out.writeLong(seqNo); + out.writeEnum(restStatus); + forecaster.writeTo(out); + if (returnJob) { + out.writeBoolean(true); // returnJob is true + forecastJob.writeTo(out); + } else { + out.writeBoolean(false); // returnJob is false + } + out.writeBoolean(returnTask); + if (realtimeTask != null) { + out.writeBoolean(true); + realtimeTask.writeTo(out); + } else { + out.writeBoolean(false); + } + if (historicalTask != null) { + out.writeBoolean(true); + historicalTask.writeTo(out); + } else { + out.writeBoolean(false); + } + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (profileResponse) { + // TODO: emit profile fields once forecaster/entity profile is supported + } else { + builder.startObject(); + builder.field(RestHandlerUtils._ID, id); + builder.field(RestHandlerUtils._VERSION, version); + builder.field(RestHandlerUtils._PRIMARY_TERM, primaryTerm); + builder.field(RestHandlerUtils._SEQ_NO, seqNo); + builder.field(RestHandlerUtils.REST_STATUS, restStatus); + builder.field(RestHandlerUtils.FORECASTER, forecaster); + if (returnJob) { + builder.field(RestHandlerUtils.FORECASTER_JOB, forecastJob); + } + if (returnTask) { + builder.field(RestHandlerUtils.REALTIME_TASK, realtimeTask); + builder.field(RestHandlerUtils.HISTORICAL_ANALYSIS_TASK, historicalTask); + } + builder.endObject(); + } + return builder; + } + + public Job getForecastJob() { + return forecastJob; + } + + public ForecastTask getRealtimeTask() { + return realtimeTask; + } + + public ForecastTask getHistoricalTask() { + return historicalTask; + } + + public Forecaster getForecaster() { + return forecaster; + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java new file mode 100644 index 000000000..e82ba846f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/GetForecasterTransportAction.java @@ -0,0 +1,127 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.forecast.transport; + +import java.util.Optional; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.ForecastTask; +import org.opensearch.forecast.model.ForecastTaskType; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.transport.GetConfigRequest; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class GetForecasterTransportAction extends + BaseGetConfigTransportAction { + + // private final Set allProfileTypeStrs; + // private final Set allProfileTypes; + // private final Set defaultDetectorProfileTypes; + // private final Set allEntityProfileTypeStrs; + // private final Set allEntityProfileTypes; + // private final Set defaultEntityProfileTypes; + + @Inject + public GetForecasterTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + ForecastTaskManager forecastTaskManager + ) { + super( + transportService, + nodeFilter, + actionFilters, + clusterService, + client, + clientUtil, + settings, + xContentRegistry, + forecastTaskManager, + GetForecasterAction.NAME, + Forecaster.class, + Forecaster.FORECAST_PARSE_FIELD_NAME, + ForecastTaskType.ALL_FORECAST_TASK_TYPES, + ForecastTaskType.FORECAST_REALTIME_HC_FORECASTER.name(), + ForecastTaskType.FORECAST_REALTIME_SINGLE_STREAM.name(), + ForecastTaskType.FORECAST_HISTORICAL_HC_FORECASTER.name(), + ForecastTaskType.FORECAST_HISTORICAL_SINGLE_STREAM.name(), + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + GetForecasterResponse.class + ); + } + + @Override + protected void getExecuteProfile( + GetConfigRequest request, + Entity entity, + String typesStr, + boolean all, + String configId, + ActionListener listener + ) { + // TODO Auto-generated method stub + + } + + @Override + protected GetForecasterResponse createResponse( + long version, + String id, + long primaryTerm, + long seqNo, + Forecaster config, + Job job, + boolean returnJob, + Optional realtimeTask, + Optional historicalTask, + boolean returnTask, + RestStatus restStatus + ) { + return new GetForecasterResponse( + id, + version, + primaryTerm, + seqNo, + config, + job, + returnJob, + realtimeTask.orElse(null), + historicalTask.orElse(null), + returnTask, + restStatus, + false + ); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java 
new file mode 100644 index 000000000..23613a89f --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterAction.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; + +public class IndexForecasterAction extends ActionType { + // External action used by public-facing REST APIs. + public static final String NAME = ForecastCommonValue.EXTERNAL_ACTION_PREFIX + "forecaster/write"; + public static final IndexForecasterAction INSTANCE = new IndexForecasterAction(); + + private IndexForecasterAction() { + super(NAME, IndexForecasterResponse::new); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java new file mode 100644 index 000000000..60a3a1964 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterRequest.java @@ -0,0 +1,144 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.rest.RestRequest; + +public class IndexForecasterRequest extends ActionRequest { + private String forecastID; + private long seqNo; + private long primaryTerm; + private WriteRequest.RefreshPolicy refreshPolicy; + private Forecaster forecaster; + private RestRequest.Method method; + private TimeValue requestTimeout; + private Integer maxSingleStreamForecasters; + private Integer maxHCForecasters; + private Integer maxForecastFeatures; + private Integer maxCategoricalFields; + + public IndexForecasterRequest(StreamInput in) throws IOException { + super(in); + forecastID = in.readString(); + seqNo = in.readLong(); + primaryTerm = in.readLong(); + refreshPolicy = in.readEnum(WriteRequest.RefreshPolicy.class); + forecaster = new Forecaster(in); + method = in.readEnum(RestRequest.Method.class); + requestTimeout = in.readTimeValue(); + maxSingleStreamForecasters = in.readInt(); + maxHCForecasters = in.readInt(); + maxForecastFeatures = in.readInt(); + maxCategoricalFields = in.readInt(); + } + + public IndexForecasterRequest( + String forecasterID, + long seqNo, + long primaryTerm, + WriteRequest.RefreshPolicy refreshPolicy, + Forecaster forecaster, + RestRequest.Method method, + TimeValue requestTimeout, + Integer maxSingleStreamForecasters, + Integer maxHCForecasters, + Integer maxForecastFeatures, + Integer maxCategoricalFields + ) { + super(); + this.forecastID = forecasterID; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.refreshPolicy = refreshPolicy; + this.forecaster = forecaster; + this.method = method; + this.requestTimeout = requestTimeout; + this.maxSingleStreamForecasters = maxSingleStreamForecasters; + this.maxHCForecasters = maxHCForecasters; + this.maxForecastFeatures = maxForecastFeatures; + this.maxCategoricalFields = maxCategoricalFields; + } + + public String getForecasterID() { + return forecastID; + } + + public long getSeqNo() { + return seqNo; + } + + public long getPrimaryTerm() { + return primaryTerm; + } + + public WriteRequest.RefreshPolicy getRefreshPolicy() { + return refreshPolicy; + } + + public Forecaster getForecaster() { + return forecaster; + } + + public RestRequest.Method getMethod() { + return method; + } + + public TimeValue getRequestTimeout() { + return requestTimeout; + } + + public Integer getMaxSingleStreamForecasters() { + return maxSingleStreamForecasters; + } + + public Integer getMaxHCForecasters() { + return maxHCForecasters; + } + + public Integer getMaxForecastFeatures() { + return maxForecastFeatures; + } + + public Integer getMaxCategoricalFields() { + return maxCategoricalFields; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(forecastID); + out.writeLong(seqNo); + out.writeLong(primaryTerm); + out.writeEnum(refreshPolicy); + forecaster.writeTo(out); + out.writeEnum(method); + out.writeTimeValue(requestTimeout); + out.writeInt(maxSingleStreamForecasters); + out.writeInt(maxHCForecasters); + out.writeInt(maxForecastFeatures); + out.writeInt(maxCategoricalFields); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java similarity index 76% rename from src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java rename to src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java index 157d50000..e3073afaf 100644 --- a/src/main/java/org/opensearch/ad/transport/AnomalyDetectorJobResponse.java +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterResponse.java @@ -9,7 +9,7 @@ * GitHub history for details.
*/ -package org.opensearch.ad.transport; +package org.opensearch.forecast.transport; import java.io.IOException; @@ -19,29 +19,33 @@ import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.timeseries.util.RestHandlerUtils; -public class AnomalyDetectorJobResponse extends ActionResponse implements ToXContentObject { +public class IndexForecasterResponse extends ActionResponse implements ToXContentObject { private final String id; private final long version; private final long seqNo; private final long primaryTerm; + private final Forecaster forecaster; private final RestStatus restStatus; - public AnomalyDetectorJobResponse(StreamInput in) throws IOException { + public IndexForecasterResponse(StreamInput in) throws IOException { super(in); id = in.readString(); version = in.readLong(); seqNo = in.readLong(); primaryTerm = in.readLong(); + forecaster = new Forecaster(in); restStatus = in.readEnum(RestStatus.class); } - public AnomalyDetectorJobResponse(String id, long version, long seqNo, long primaryTerm, RestStatus restStatus) { + public IndexForecasterResponse(String id, long version, long seqNo, long primaryTerm, Forecaster forecaster, RestStatus restStatus) { this.id = id; this.version = version; this.seqNo = seqNo; this.primaryTerm = primaryTerm; + this.forecaster = forecaster; this.restStatus = restStatus; } @@ -55,6 +59,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(version); out.writeLong(seqNo); out.writeLong(primaryTerm); + forecaster.writeTo(out); out.writeEnum(restStatus); } @@ -65,6 +70,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(RestHandlerUtils._ID, id) .field(RestHandlerUtils._VERSION, version) .field(RestHandlerUtils._SEQ_NO, seqNo) + .field(RestHandlerUtils.FORECASTER, forecaster) .field(RestHandlerUtils._PRIMARY_TERM, primaryTerm) .endObject(); } diff --git a/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java new file mode 100644 index 000000000..18e7e379e --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/IndexForecasterTransportAction.java @@ -0,0 +1,210 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_CREATE_FORECASTER; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_UPDATE_FORECASTER; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES; +import static org.opensearch.forecast.util.ForecastParseUtils.checkFilterByBackendRoles; +import static org.opensearch.forecast.util.ForecastParseUtils.getForecaster; +import static org.opensearch.forecast.util.ForecastParseUtils.getUserContext; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.List; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.rest.handler.IndexForecasterActionHandler; +import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public class IndexForecasterTransportAction extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(IndexForecasterTransportAction.class); + private final Client client; + private final SecurityClientUtil clientUtil; + private final TransportService transportService; + private final ForecastIndexManagement forecastIndices; + private final ClusterService clusterService; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final SearchFeatureDao searchFeatureDao; + + @Inject + public IndexForecasterTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + SecurityClientUtil clientUtil, + ClusterService clusterService, + Settings settings, + ForecastIndexManagement forecastIndices, + NamedXContentRegistry xContentRegistry, + SearchFeatureDao searchFeatureDao + ) { + super(IndexForecasterAction.NAME, transportService, actionFilters, IndexForecasterRequest::new); + this.client = client; + this.clientUtil = clientUtil; + this.transportService = transportService; + this.clusterService = clusterService; + this.forecastIndices = forecastIndices; + this.xContentRegistry = xContentRegistry; + filterByEnabled = ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES.get(settings); + 
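+ // Keep filterByEnabled in sync with dynamic updates to FORECAST_FILTER_BY_BACKEND_ROLES,
+ // so security filtering can be toggled without a node restart.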
clusterService.getClusterSettings().addSettingsUpdateConsumer(FORECAST_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + this.searchFeatureDao = searchFeatureDao; + } + + @Override + protected void doExecute(Task task, IndexForecasterRequest request, ActionListener actionListener) { + User user = getUserContext(client); + String forecasterId = request.getForecasterID(); + RestRequest.Method method = request.getMethod(); + String errorMessage = method == RestRequest.Method.PUT ? FAIL_TO_UPDATE_FORECASTER : FAIL_TO_CREATE_FORECASTER; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + forecasterId, + method, + listener, + (forecaster) -> forecastExecute(request, user, forecaster, context, listener) + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private void resolveUserAndExecute( + User requestedUser, + String forecasterId, + RestRequest.Method method, + ActionListener listener, + Consumer function + ) { + try { + // requestedUser == null means security is disabled or the user is a superadmin. In that case we + // don't need to check whether the requesting user has access to the forecaster, but we still fetch + // the current forecaster so we can preserve its user data. + boolean filterByBackendRole = requestedUser == null ? false : filterByEnabled; + + // Check if the user has backend roles. + // When filtering is enabled, block users from creating or updating forecasters if they have no backend roles. + if (filterByBackendRole && !checkFilterByBackendRoles(requestedUser, listener)) { + return; + } + if (method == RestRequest.Method.PUT) { + // Update forecaster request: fetch the forecaster and verify backend roles to check + // whether the user has permission to update it. + getForecaster( + requestedUser, + forecasterId, + listener, + function, + client, + clusterService, + xContentRegistry, + filterByBackendRole + ); + } else { + // Create forecaster. No need to fetch a current forecaster. + function.accept(null); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void forecastExecute( + IndexForecasterRequest request, + User user, + Forecaster currentForecaster, + ThreadContext.StoredContext storedContext, + ActionListener listener + ) { + forecastIndices.update(); + String forecasterId = request.getForecasterID(); + long seqNo = request.getSeqNo(); + long primaryTerm = request.getPrimaryTerm(); + WriteRequest.RefreshPolicy refreshPolicy = request.getRefreshPolicy(); + Forecaster forecaster = request.getForecaster(); + RestRequest.Method method = request.getMethod(); + TimeValue requestTimeout = request.getRequestTimeout(); + Integer maxSingleStreamForecasters = request.getMaxSingleStreamForecasters(); + Integer maxHCForecasters = request.getMaxHCForecasters(); + Integer maxForecastFeatures = request.getMaxForecastFeatures(); + Integer maxCategoricalFields = request.getMaxCategoricalFields(); + + storedContext.restore(); + checkIndicesAndExecute(forecaster.getIndices(), () -> { + // Don't replace the forecaster's user when updating a forecaster. + // Github issue: https://github.com/opensearch-project/anomaly-detection/issues/124 + User forecastUser = currentForecaster == null ? user : currentForecaster.getUser(); + IndexForecasterActionHandler indexForecasterActionHandler = new IndexForecasterActionHandler( + clusterService, + client, + clientUtil, + transportService, + forecastIndices, + forecasterId, + seqNo, + primaryTerm, + refreshPolicy, + forecaster, + requestTimeout, + maxSingleStreamForecasters, + maxHCForecasters, + maxForecastFeatures, + maxCategoricalFields, + method, + xContentRegistry, + forecastUser, + searchFeatureDao + ); + indexForecasterActionHandler.start(listener); + }, listener); + } + + private void checkIndicesAndExecute(List indices, ExecutorFunction function, ActionListener listener) { + SearchRequest searchRequest = new SearchRequest() + .indices(indices.toArray(new String[0])) + .source(new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery())); + client.search(searchRequest, ActionListener.wrap(r -> { function.execute(); }, e -> { + // Due to below issue with security plugin, we get security_exception when invalid index name is mentioned. + // https://github.com/opendistro-for-elasticsearch/security/issues/718 + LOG.error(e); + listener.onFailure(e); + })); + } +} diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java new file mode 100644 index 000000000..9b38db2eb --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterAction.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.forecast.constant.ForecastCommonValue; +import org.opensearch.timeseries.transport.StopConfigResponse; + +public class StopForecasterAction extends ActionType { + // Internal action, not exposed through public-facing REST APIs. + public static final String NAME = ForecastCommonValue.INTERNAL_ACTION_PREFIX + "forecaster/stop"; + public static final StopForecasterAction INSTANCE = new StopForecasterAction(); + + private StopForecasterAction() { + super(NAME, StopConfigResponse::new); + } + +}
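A hypothetical trigger of the stop action (the job handler fires this when a forecaster is stopped; the StopConfigRequest constructor shape is assumed here):

```java
// Hypothetical: fan out model deletion for a stopping forecaster. The response's
// boolean reflects whether every data node dropped its models.
client.execute(
    StopForecasterAction.INSTANCE,
    new StopConfigRequest(forecasterId), // assumed single-arg constructor
    ActionListener.wrap(
        response -> LOG.info("stopped forecaster {}", forecasterId),
        exception -> LOG.error("failed to stop forecaster " + forecasterId, exception)
    )
);
```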
diff --git a/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java new file mode 100644 index 000000000..ed751085b --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/StopForecasterTransportAction.java @@ -0,0 +1,85 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport; + +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_STOP_FORECASTER; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.inject.Inject; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.transport.TransportService; + +public class StopForecasterTransportAction extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(StopForecasterTransportAction.class); + + private final Client client; + private final DiscoveryNodeFilterer nodeFilter; + + @Inject + public StopForecasterTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + Client client + ) { + super(StopForecasterAction.NAME, transportService, actionFilters, StopConfigRequest::new); + this.client = client; + this.nodeFilter = nodeFilter; + } + + @Override + protected void doExecute(Task task, ActionRequest actionRequest, ActionListener listener) { + StopConfigRequest request = StopConfigRequest.fromActionRequest(actionRequest); + String configId = request.getConfigID(); + try { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + DeleteModelRequest modelDeleteRequest = new DeleteModelRequest(configId, dataNodes); + client.execute(DeleteForecastModelAction.INSTANCE, modelDeleteRequest, ActionListener.wrap(response -> { + if (response.hasFailures()) { + LOG.warn("Cannot delete all models of forecaster {}", configId); + for (FailedNodeException failedNodeException : response.failures()) { + LOG.warn("Deleting models of node has exception", failedNodeException); + } + // If customers are running an updated forecaster and the old checkpoints have not + // been deleted, they would run into trouble. + listener.onResponse(new StopConfigResponse(false)); + } else { + LOG.info("Deleted models of forecaster {}", configId); + listener.onResponse(new StopConfigResponse(true)); + } + }, exception -> { + LOG.error(new ParameterizedMessage("Deletion of forecaster [{}] has exception.", configId), exception); + listener.onResponse(new StopConfigResponse(false)); + })); + } catch (Exception e) { + LOG.error(FAIL_TO_STOP_FORECASTER + " " + configId, e); + Throwable cause = ExceptionsHelper.unwrapCause(e); + listener.onFailure(new InternalFailure(configId, FAIL_TO_STOP_FORECASTER, cause)); + } + } +}
diff --git a/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..43ed022a1 --- /dev/null +++ b/src/main/java/org/opensearch/forecast/transport/handler/ForecastIndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,51 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.forecast.transport.handler; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkRequest; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +public class ForecastIndexMemoryPressureAwareResultHandler extends + IndexMemoryPressureAwareResultHandler { + private static final Logger LOG = LogManager.getLogger(ForecastIndexMemoryPressureAwareResultHandler.class); + + @Inject + public ForecastIndexMemoryPressureAwareResultHandler(Client client, ForecastIndexManagement forecastIndices) { + super(client, forecastIndices); + } + + @Override + public void bulk(ForecastResultBulkRequest currentBulkRequest, ActionListener listener) { + if (currentBulkRequest.numberOfActions() <= 0) { + listener.onFailure(new TimeSeriesException("no result to save")); + return; + } + client.execute(ForecastResultBulkAction.INSTANCE, currentBulkRequest, ActionListener.wrap(response -> { + LOG.debug(CommonMessages.SUCCESS_SAVING_RESULT_MSG); + listener.onResponse(response); + }, exception -> { + LOG.error("Error bulk indexing forecast results", exception); + listener.onFailure(exception); + })); + } +}
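For context, a sketch of how a batch flows through this handler. This is a hypothetical caller (in the patch the result write worker owns this wiring), and the no-arg ForecastResultBulkRequest constructor is assumed:

```java
// Hypothetical: flush a batch of forecast results through the
// memory-pressure-aware handler.
ForecastResultBulkRequest batch = new ForecastResultBulkRequest();
// ... add ForecastResultWriteRequest entries to the batch ...
handler.bulk(batch, ActionListener.wrap(
    response -> LOG.debug("batch of forecast results saved"),
    exception -> LOG.error("failed to save forecast results", exception)
));
```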
org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.Forecaster; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class ForecastParseUtils { + private static final Logger logger = LogManager.getLogger(ForecastParseUtils.class); + + public static boolean checkFilterByBackendRoles(User requestedUser, ActionListener listener) { + if (requestedUser == null) { + return false; + } + if (requestedUser.getBackendRoles().isEmpty()) { + listener + .onFailure( + new TimeSeriesException( + "Filter by backend roles is enabled and user " + requestedUser.getName() + " does not have backend roles configured" + ) + ); + return false; + } + return true; + } + + /** + * If filterByBackendRole is true, get the forecaster and check if the user has permission to access it, + * then execute the function; otherwise, get the forecaster and execute the function directly. + * @param requestUser user from request + * @param forecasterId forecaster id + * @param listener action listener + * @param function consumer function + * @param client client + * @param clusterService cluster service + * @param xContentRegistry XContent registry + * @param filterByBackendRole filter by backend role or not + */ + public static void getForecaster( + User requestUser, + String forecasterId, + ActionListener listener, + Consumer<Forecaster> function, + Client client, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + boolean filterByBackendRole + ) { + if (clusterService.state().metadata().indices().containsKey(ForecastIndex.CONFIG.getIndexName())) { + GetRequest request = new GetRequest(ForecastIndex.CONFIG.getIndexName()).id(forecasterId); + client + .get( + request, + ActionListener + .wrap( + response -> onGetForecasterResponse( + response, + requestUser, + forecasterId, + listener, + function, + xContentRegistry, + filterByBackendRole + ), + exception -> { + logger.error("Failed to get forecaster: " + forecasterId, exception); + listener.onFailure(exception); + } + ) + ); + } else { + listener.onFailure(new IndexNotFoundException(ForecastIndex.CONFIG.getIndexName())); + } + } + + public static void onGetForecasterResponse( + GetResponse response, + User requestUser, + String forecastId, + ActionListener listener, + Consumer<Forecaster> function, + NamedXContentRegistry xContentRegistry, + boolean filterByBackendRole + ) { + if (response.isExists()) { + try ( + XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Forecaster forecaster = Forecaster.parse(parser); + User resourceUser = forecaster.getUser(); + + if (!filterByBackendRole || checkUserPermissions(requestUser, resourceUser, forecastId)) { + function.accept(forecaster); + } else { + logger.debug("User: " + requestUser.getName() + " does not have permissions to access forecaster: " + forecastId); + listener.onFailure(new TimeSeriesException(NO_PERMISSION_TO_ACCESS_FORECASTER + forecastId)); + } + } catch (Exception e) { + listener.onFailure(new TimeSeriesException(FAIL_TO_GET_USER_INFO + forecastId)); + } + } else { + listener.onFailure(new ResourceNotFoundException(forecastId, FAIL_TO_FIND_FORECASTER_MSG + forecastId)); + } + } + + private static boolean checkUserPermissions(User
requestedUser, User resourceUser, String forecasterId) throws Exception { + if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { + return false; + } + // Check if requested user has backend role required to access the resource + for (String backendRole : requestedUser.getBackendRoles()) { + if (resourceUser.getBackendRoles().contains(backendRole)) { + logger + .debug( + "User: " + + requestedUser.getName() + + " has backend role: " + + backendRole + + " permissions to access forecaster: " + + forecasterId + ); + return true; + } + } + return false; + } + + public static User getUserContext(Client client) { + String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + logger.debug("Filtering result by " + userStr); + return User.parse(userStr); + } +} diff --git a/src/main/java/org/opensearch/ad/DetectorModelSize.java b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java similarity index 74% rename from src/main/java/org/opensearch/ad/DetectorModelSize.java rename to src/main/java/org/opensearch/timeseries/AnalysisModelSize.java index 52e4660e6..5e70c456c 100644 --- a/src/main/java/org/opensearch/ad/DetectorModelSize.java +++ b/src/main/java/org/opensearch/timeseries/AnalysisModelSize.java @@ -9,16 +9,16 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.util.Map; -public interface DetectorModelSize { +public interface AnalysisModelSize { /** * Gets all of a detector's model sizes hosted on a node * - * @param detectorId Detector Id + * @param id Analysis Id * @return a map of model id to its memory size */ - Map getModelSize(String detectorId); + Map getModelSize(String id); } diff --git a/src/main/java/org/opensearch/timeseries/AnalysisType.java b/src/main/java/org/opensearch/timeseries/AnalysisType.java new file mode 100644 index 000000000..7d7cc805e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/AnalysisType.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +public enum AnalysisType { + AD, + FORECAST +} diff --git a/src/main/java/org/opensearch/ad/CleanState.java b/src/main/java/org/opensearch/timeseries/CleanState.java similarity index 94% rename from src/main/java/org/opensearch/ad/CleanState.java rename to src/main/java/org/opensearch/timeseries/CleanState.java index ae8085e88..fac03b453 100644 --- a/src/main/java/org/opensearch/ad/CleanState.java +++ b/src/main/java/org/opensearch/timeseries/CleanState.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; /** * Represent a state organized via detectorId. When deleting a detector's state, diff --git a/src/main/java/org/opensearch/timeseries/ExceptionRecorder.java b/src/main/java/org/opensearch/timeseries/ExceptionRecorder.java new file mode 100644 index 000000000..5b692e96f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ExceptionRecorder.java @@ -0,0 +1,20 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries; + +import java.util.Optional; + +public interface ExceptionRecorder { + public void setException(String id, Exception e); + + public Optional fetchExceptionAndClear(String id); +} diff --git a/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java new file mode 100644 index 000000000..ea6f624f2 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ExecuteResultResponseRecorder.java @@ -0,0 +1,377 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import java.time.Instant; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.model.DetectorProfileName; +import org.opensearch.ad.transport.ProfileAction; +import org.opensearch.ad.transport.ProfileRequest; +import org.opensearch.ad.transport.RCFPollingAction; +import org.opensearch.ad.transport.RCFPollingRequest; +import org.opensearch.client.Client; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.commons.authuser.User; +import org.opensearch.search.SearchHits; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ExceptionUtil; + +public abstract class ExecuteResultResponseRecorder< + IndexType extends Enum & TimeSeriesIndex, + IndexManagementType extends IndexManagement, + TaskCacheManagerType extends TaskCacheManager, + TaskTypeEnum extends TaskType, + TaskClass extends TimeSeriesTask, + TaskManagerType extends TaskManager, + IndexableResultType extends IndexableResult + > { + + private static final Logger log = LogManager.getLogger(ExecuteResultResponseRecorder.class); + + protected IndexManagementType indexManagement; + private ResultBulkIndexingHandler resultHandler; + protected TaskManagerType taskManager; + private DiscoveryNodeFilterer nodeFilter; + private ThreadPool threadPool; + private String threadPoolName; + private Client client; + private NodeStateManager nodeStateManager; + private TaskCacheManager taskCacheManager; + private int rcfMinSamples; + 
protected IndexType resultIndex; + private AnalysisType analysisType; + + public ExecuteResultResponseRecorder( + IndexManagementType indexManagement, + ResultBulkIndexingHandler resultHandler, + TaskManagerType taskManager, + DiscoveryNodeFilterer nodeFilter, + ThreadPool threadPool, + String threadPoolName, + Client client, + NodeStateManager nodeStateManager, + TaskCacheManager taskCacheManager, + int rcfMinSamples, + IndexType resultIndex, + AnalysisType analysisType + ) { + this.indexManagement = indexManagement; + this.resultHandler = resultHandler; + this.taskManager = taskManager; + this.nodeFilter = nodeFilter; + this.threadPool = threadPool; + this.threadPoolName = threadPoolName; + this.client = client; + this.nodeStateManager = nodeStateManager; + this.taskCacheManager = taskCacheManager; + this.rcfMinSamples = rcfMinSamples; + this.resultIndex = resultIndex; + this.analysisType = analysisType; + } + + public void indexResult( + Instant detectionStartTime, + Instant executionStartTime, + //ResultResponseType response, + ResultResponse response, + Config config + ) { + String configId = config.getId(); + try { + + if (!response.shouldSave()) { + updateRealtimeTask(response, configId); + return; + } + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + if (response.getError() != null) { + log.info("Result action for {} ran successfully but returned error {}", configId, response.getError()); + } + + List<IndexableResultType> analysisResults = response + .toIndexableResults( + configId, + dataStartTime, + dataEndTime, + executionStartTime, + Instant.now(), + indexManagement.getSchemaVersion(resultIndex), + user, + response.getError() + ); + + String customResultIndex = config.getCustomResultIndex(); + resultHandler + .bulk( + customResultIndex, + analysisResults, + configId, + ActionListener + .wrap( + r -> {}, + exception -> log.error(String.format(Locale.ROOT, "Failed to bulk index results for %s", configId), exception) + ) + ); + updateRealtimeTask(response, configId); + } catch (EndRunException e) { + throw e; + } catch (Exception e) { + log.error("Failed to index result for " + configId, e); + } + } + + /** + * + * If the result action is handled asynchronously, the response won't contain the result. + * This function waits some time before fetching the update. + * + * @param response response returned from executing the result action + * @param configId config Id + */ + protected void delayedUpdate(ResultResponse response, String configId) { + DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes(); + Set<DetectorProfileName> profiles = new HashSet<>(); + profiles.add(DetectorProfileName.INIT_PROGRESS); + ProfileRequest profileRequest = new ProfileRequest(configId, profiles, true, dataNodes); + Runnable profileHCInitProgress = () -> { + // TODO: change to use customized profile action. Now it is limited to AD profile.
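+ // A config-type-aware call could be sketched as (hypothetical; such a generalized profile action does not exist yet): + // client.execute(profileActionFor(analysisType), profileRequest, ActionListener.wrap(...)); + // Until then, forecasting cannot reuse this path and we go through the AD-specific ProfileAction below.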
+ client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> { + log.debug("Update latest realtime task for config {}, total updates: {}", configId, r.getTotalUpdates()); + updateLatestRealtimeTask( + configId, + null, + r.getTotalUpdates(), + response.getConfigIntervalInMinutes(), + response.getError() + ); + }, e -> { log.error("Failed to update latest realtime task for " + configId, e); })); + }; + if (!taskManager.isHCRealtimeTaskStartInitializing(configId)) { + // A real-time init progress of 0 may mean this is a newly started detector. + // Delay the real-time cache update by one minute. If we are in init status, the delay may give model training time to + // finish, so we can mark the detector running immediately instead of waiting for the next interval. + threadPool + .schedule(profileHCInitProgress, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + profileHCInitProgress.run(); + } + } + + protected void updateLatestRealtimeTask( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error + ) { + // Don't log at info level as this will be printed repeatedly in each interval + ActionListener<UpdateResponse> listener = ActionListener.wrap(r -> { + if (r != null) { + log.debug("Updated latest realtime task successfully for config {}, taskState: {}", configId, taskState); + } + }, e -> { + if ((e instanceof ResourceNotFoundException) && e.getMessage().contains(CommonMessages.CAN_NOT_FIND_LATEST_TASK)) { + // Clear realtime task cache, will recreate task in next run, check ADResultProcessor. + log.error("Can't find latest realtime task of config " + configId); + taskManager.removeRealtimeTaskCache(configId); + } else { + log.error("Failed to update latest realtime task for config " + configId, e); + } + }); + + // rcfTotalUpdates is null when we save exception messages + if (!taskCacheManager.hasQueriedResultIndex(configId) && rcfTotalUpdates != null && rcfTotalUpdates < rcfMinSamples) { + // Confirm the total updates number since it is possible that we already have results after the job enabling time. + // If yes, total updates should be at least rcfMinSamples so that the init progress reaches 100%. + confirmTotalRCFUpdatesFound( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + ActionListener + .wrap( + r -> taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, r, configIntervalInMinutes, error, listener), + e -> { + log.error("Failed to confirm RCF update", e); + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode( + configId, + taskState, + rcfTotalUpdates, + configIntervalInMinutes, + error, + listener + ); + } + ) + ); + } else { + taskManager + .updateLatestRealtimeTaskOnCoordinatingNode(configId, taskState, rcfTotalUpdates, configIntervalInMinutes, error, listener); + } + } + + /** + * Besides indexing the result with the exception, this function also updates the task state after + * 60s if the exception is cold-start related (index not found exceptions) for a single-stream detector.
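+ * (Per the body below, the delayed update only applies when the error indicates a missing model and the config + * is not high-cardinality; all other cases update the task state immediately.)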
+ * + * @param executeStartTime execution start time + * @param executeEndTime execution end time + * @param errorMessage Error message to record + * @param taskState task state (e.g., stopped) + * @param config config accessor + */ + public void indexResultException( + Instant executeStartTime, + Instant executeEndTime, + String errorMessage, + String taskState, + Config config + ) { + String configId = config.getId(); + try { + IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) config.getWindowDelay(); + Instant dataStartTime = executeStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + Instant dataEndTime = executeEndTime.minus(windowDelay.getInterval(), windowDelay.getUnit()); + User user = config.getUser(); + + IndexableResultType resultToSave = createErrorResult(configId, dataStartTime, dataEndTime, executeEndTime, errorMessage, user); + String resultIndex = config.getCustomResultIndex(); + if (resultIndex != null && !indexManagement.doesIndexExist(resultIndex)) { + // Set result index as null, will write exception to default result index. + resultHandler.index(resultToSave, configId, null); + } else { + resultHandler.index(resultToSave, configId, resultIndex); + } + + if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !config.isHighCardinality()) { + // single stream detector raises ResourceNotFoundException containing ADCommonMessages.NO_CHECKPOINT_ERR_MSG + // when there is no checkpoint. + // Delay real time cache update by one minute so we will have trained models by then and update the state + // document accordingly. + threadPool.schedule(() -> { + RCFPollingRequest request = new RCFPollingRequest(configId); + client.execute(RCFPollingAction.INSTANCE, request, ActionListener.wrap(rcfPollResponse -> { + long totalUpdates = rcfPollResponse.getTotalUpdates(); + // if there are updates, don't record failures + updateLatestRealtimeTask( + configId, + taskState, + totalUpdates, + config.getIntervalInMinutes(), + totalUpdates > 0 ? 
"" : errorMessage + ); + }, e -> { + log.error("Fail to execute RCFRollingAction", e); + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + })); + }, new TimeValue(60, TimeUnit.SECONDS), threadPoolName); + } else { + updateLatestRealtimeTask(configId, taskState, null, null, errorMessage); + } + + } catch (Exception e) { + log.error("Failed to index anomaly result for " + configId, e); + } + } + + private void confirmTotalRCFUpdatesFound( + String configId, + String taskState, + Long rcfTotalUpdates, + Long configIntervalInMinutes, + String error, + ActionListener listener + ) { + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get detector")); + return; + } + nodeStateManager.getJob(configId, ActionListener.wrap(jobOptional -> { + if (!jobOptional.isPresent()) { + listener.onFailure(new TimeSeriesException(configId, "fail to get job")); + return; + } + + ProfileUtil + .confirmRealtimeInitStatus( + configOptional.get(), + jobOptional.get().getEnabledTime().toEpochMilli(), + client, + analysisType, + ActionListener.wrap(searchResponse -> { + ActionListener.completeWith(listener, () -> { + SearchHits hits = searchResponse.getHits(); + Long correctedTotalUpdates = rcfTotalUpdates; + if (hits.getTotalHits().value > 0L) { + // correct the number if we have already had results after job enabling time + // so that the detector won't stay initialized + correctedTotalUpdates = Long.valueOf(rcfMinSamples); + } + taskCacheManager.markResultIndexQueried(configId); + return correctedTotalUpdates; + }); + }, exception -> { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + // anomaly result index is not created yet + taskCacheManager.markResultIndexQueried(configId); + listener.onResponse(0L); + } else { + listener.onFailure(exception); + } + }) + ); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get job")))); + }, e -> listener.onFailure(new TimeSeriesException(configId, "fail to get config")))); + } + + protected abstract IndexableResultType createErrorResult( + String configId, + Instant dataStartTime, + Instant dataEndTime, + Instant executeEndTime, + String errorMessage, + User user + ); + + //protected abstract void updateRealtimeTask(ResultResponseType response, String configId); + protected abstract void updateRealtimeTask(ResultResponse response, String configId); +} diff --git a/src/main/java/org/opensearch/ad/ExpiringState.java b/src/main/java/org/opensearch/timeseries/ExpiringState.java similarity index 94% rename from src/main/java/org/opensearch/ad/ExpiringState.java rename to src/main/java/org/opensearch/timeseries/ExpiringState.java index 0df0e1f51..f5e6d3669 100644 --- a/src/main/java/org/opensearch/ad/ExpiringState.java +++ b/src/main/java/org/opensearch/timeseries/ExpiringState.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.time.Duration; import java.time.Instant; diff --git a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java b/src/main/java/org/opensearch/timeseries/JobProcessor.java similarity index 64% rename from src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java rename to src/main/java/org/opensearch/timeseries/JobProcessor.java index 7c2427847..4c750172c 100644 --- a/src/main/java/org/opensearch/ad/AnomalyDetectorJobRunner.java +++ b/src/main/java/org/opensearch/timeseries/JobProcessor.java @@ -9,12 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import static org.opensearch.action.DocWriteResponse.Result.CREATED; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME; import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import java.io.IOException; @@ -27,22 +26,14 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionType; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.model.ADTaskState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.transport.AnomalyResultAction; -import org.opensearch.ad.transport.AnomalyResultRequest; -import org.opensearch.ad.transport.AnomalyResultResponse; import org.opensearch.ad.transport.AnomalyResultTransportAction; -import org.opensearch.ad.util.SecurityUtil; import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; @@ -53,8 +44,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.jobscheduler.spi.JobExecutionContext; import org.opensearch.jobscheduler.spi.LockModel; -import org.opensearch.jobscheduler.spi.ScheduledJobParameter; -import org.opensearch.jobscheduler.spi.ScheduledJobRunner; import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.jobscheduler.spi.utils.LockService; import org.opensearch.threadpool.ThreadPool; @@ -63,41 +52,52 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import 
org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.util.SecurityUtil; import com.google.common.base.Throwables; /** - * JobScheduler will call AD job runner to get anomaly result periodically + * JobScheduler will call job runner to get time series analysis result periodically */ -public class AnomalyDetectorJobRunner implements ScheduledJobRunner { - private static final Logger log = LogManager.getLogger(AnomalyDetectorJobRunner.class); - private static AnomalyDetectorJobRunner INSTANCE; +public abstract class JobProcessor & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, TaskManagerType extends TaskManager, IndexableResultType extends IndexableResult, ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder> { + + private static final Logger log = LogManager.getLogger(JobProcessor.class); + private Settings settings; private int maxRetryForEndRunException; private Client client; private ThreadPool threadPool; - private ConcurrentHashMap detectorEndRunExceptionCount; - private ADIndexManagement anomalyDetectionIndices; - private ADTaskManager adTaskManager; + private ConcurrentHashMap endRunExceptionCount; + private IndexManagementType indexManagement; + private TaskManagerType taskManager; private NodeStateManager nodeStateManager; - private ExecuteADResultResponseRecorder recorder; - - public static AnomalyDetectorJobRunner getJobRunnerInstance() { - if (INSTANCE != null) { - return INSTANCE; - } - synchronized (AnomalyDetectorJobRunner.class) { - if (INSTANCE != null) { - return INSTANCE; - } - INSTANCE = new AnomalyDetectorJobRunner(); - return INSTANCE; - } - } - - private AnomalyDetectorJobRunner() { + private ExecuteResultResponseRecorderType recorder; + private AnalysisType analysisType; + private String threadPoolName; + private ActionType> resultAction; + + protected JobProcessor( + AnalysisType analysisType, + String threadPoolName, + ActionType> resultAction + ) { // Singleton class, use getJobRunnerInstance method instead of constructor - this.detectorEndRunExceptionCount = new ConcurrentHashMap<>(); + this.endRunExceptionCount = new ConcurrentHashMap<>(); + this.analysisType = analysisType; + this.threadPoolName = threadPoolName; + this.resultAction = resultAction; } public void setClient(Client client) { @@ -108,52 +108,48 @@ public void setThreadPool(ThreadPool threadPool) { this.threadPool = threadPool; } - public void setSettings(Settings settings) { + protected void registerSettings(Settings settings, Setting maxRetryForEndRunExceptionSetting) { this.settings = settings; - this.maxRetryForEndRunException = AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings); + this.maxRetryForEndRunException = maxRetryForEndRunExceptionSetting.get(settings); } - public void setAdTaskManager(ADTaskManager adTaskManager) { - this.adTaskManager = adTaskManager; + public void setTaskManager(TaskManagerType adTaskManager) { + this.taskManager = adTaskManager; } - public void setAnomalyDetectionIndices(ADIndexManagement anomalyDetectionIndices) { - this.anomalyDetectionIndices = anomalyDetectionIndices; + public void setIndexManagement(IndexManagementType anomalyDetectionIndices) { + this.indexManagement = anomalyDetectionIndices; } public void setNodeStateManager(NodeStateManager 
nodeStateManager) { this.nodeStateManager = nodeStateManager; } - public void setExecuteADResultResponseRecorder(ExecuteADResultResponseRecorder recorder) { + public void setExecuteResultResponseRecorder(ExecuteResultResponseRecorderType recorder) { this.recorder = recorder; } - @Override - public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { - String detectorId = scheduledJobParameter.getName(); - log.info("Start to run AD job {}", detectorId); - adTaskManager.refreshRealtimeJobRunTime(detectorId); - if (!(scheduledJobParameter instanceof AnomalyDetectorJob)) { - throw new IllegalArgumentException( - "Job parameter is not instance of AnomalyDetectorJob, type: " + scheduledJobParameter.getClass().getCanonicalName() - ); - } - AnomalyDetectorJob jobParameter = (AnomalyDetectorJob) scheduledJobParameter; + public void process(Job jobParameter, JobExecutionContext context) { + String configId = jobParameter.getName(); + + log.info("Start to run {} job {}", analysisType, configId); + + taskManager.refreshRealtimeJobRunTime(configId); + Instant executionStartTime = Instant.now(); IntervalSchedule schedule = (IntervalSchedule) jobParameter.getSchedule(); - Instant detectionStartTime = executionStartTime.minus(schedule.getInterval(), schedule.getUnit()); + Instant analysisStartTime = executionStartTime.minus(schedule.getInterval(), schedule.getUnit()); final LockService lockService = context.getLockService(); Runnable runnable = () -> { try { - nodeStateManager.getAnomalyDetector(detectorId, ActionListener.wrap(detectorOptional -> { - if (!detectorOptional.isPresent()) { - log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId)); + nodeStateManager.getConfig(configId, analysisType, ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + log.error(new ParameterizedMessage("fail to get config [{}]", configId)); return; } - AnomalyDetector detector = detectorOptional.get(); + Config config = configOptional.get(); if (jobParameter.getLockDurationSeconds() != null) { lockService @@ -162,84 +158,84 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont context, ActionListener .wrap( - lock -> runAdJob( + lock -> runJob( jobParameter, lockService, lock, - detectionStartTime, + analysisStartTime, executionStartTime, recorder, - detector + config ), exception -> { - indexAnomalyResultException( + indexResultException( jobParameter, lockService, null, - detectionStartTime, + analysisStartTime, executionStartTime, exception, false, recorder, - detector + config ); - throw new IllegalStateException("Failed to acquire lock for AD job: " + detectorId); + throw new IllegalStateException("Failed to acquire lock for job: " + configId); } ) ); } else { - log.warn("Can't get lock for AD job: " + detectorId); + log.warn("Can't get lock for job: " + configId); } - }, e -> log.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), e))); + }, e -> log.error(new ParameterizedMessage("fail to get config [{}]", configId), e))); } catch (Exception e) { // os log won't show anything if there is an exception happens (maybe due to running on a ExecutorService) // we at least log the error. 
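+ // Note: ExecutorService.submit captures an uncaught throwable in the returned Future and never prints it, + // so this catch-and-log is the only place the failure becomes visible.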
- log.error("Can't start AD job: " + detectorId, e); + log.error("Can't start job: " + configId, e); throw e; } }; - ExecutorService executor = threadPool.executor(AD_THREAD_POOL_NAME); + ExecutorService executor = threadPool.executor(threadPoolName); executor.submit(runnable); } /** - * Get anomaly result, index result or handle exception if failed. + * Get analysis result, index result or handle exception if failed. * * @param jobParameter scheduled job parameter * @param lockService lock service * @param lock lock to run job - * @param detectionStartTime detection start time - * @param executionStartTime detection end time + * @param analysisStartTime analysis start time + * @param analysisEndTime detection end time * @param recorder utility to record job execution result * @param detector associated detector accessor */ - protected void runAdJob( - AnomalyDetectorJob jobParameter, + protected void runJob( + Job jobParameter, LockService lockService, LockModel lock, - Instant detectionStartTime, - Instant executionStartTime, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + Instant analysisStartTime, + Instant analysisEndTime, + ExecuteResultResponseRecorderType recorder, + Config detector ) { String detectorId = jobParameter.getName(); if (lock == null) { - indexAnomalyResultException( + indexResultException( jobParameter, lockService, lock, - detectionStartTime, - executionStartTime, - "Can't run AD job due to null lock", + analysisStartTime, + analysisEndTime, + "Can't run job due to null lock", false, recorder, detector ); return; } - anomalyDetectionIndices.update(); + indexManagement.update(); User userInfo = SecurityUtil.getUserFromJob(jobParameter, settings); @@ -248,71 +244,45 @@ protected void runAdJob( String resultIndex = jobParameter.getCustomResultIndex(); if (resultIndex == null) { - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); + runJob(jobParameter, lockService, lock, analysisStartTime, analysisEndTime, detectorId, user, roles, recorder, detector); return; } ActionListener listener = ActionListener.wrap(r -> { log.debug("Custom index is valid"); }, e -> { Exception exception = new EndRunException(detectorId, e.getMessage(), true); - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, exception, recorder, detector); + handleException(jobParameter, lockService, lock, analysisStartTime, analysisEndTime, exception, recorder, detector); }); - anomalyDetectionIndices.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { + indexManagement.validateCustomIndexForBackendJob(resultIndex, detectorId, user, roles, () -> { listener.onResponse(true); - runAnomalyDetectionJob( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - detectorId, - user, - roles, - recorder, - detector - ); + runJob(jobParameter, lockService, lock, analysisStartTime, analysisEndTime, detectorId, user, roles, recorder, detector); }, listener); } - private void runAnomalyDetectionJob( - AnomalyDetectorJob jobParameter, + private void runJob( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, - String detectorId, + String configId, String user, List roles, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config detector ) { // using one thread 
in the write threadpool - try (InjectSecurity injectSecurity = new InjectSecurity(detectorId, settings, client.threadPool().getThreadContext())) { + try (InjectSecurity injectSecurity = new InjectSecurity(configId, settings, client.threadPool().getThreadContext())) { // Injecting user role to verify if the user has permissions for our API. injectSecurity.inject(user, roles); - AnomalyResultRequest request = new AnomalyResultRequest( - detectorId, - detectionStartTime.toEpochMilli(), - executionStartTime.toEpochMilli() - ); + ResultRequest request = createResultRequest(configId, detectionStartTime.toEpochMilli(), executionStartTime.toEpochMilli()); client .execute( - AnomalyResultAction.INSTANCE, + resultAction, request, ActionListener .wrap( response -> { - indexAnomalyResult( + indexResult( jobParameter, lockService, lock, @@ -324,7 +294,7 @@ private void runAnomalyDetectionJob( ); }, exception -> { - handleAdException( + handleException( jobParameter, lockService, lock, @@ -338,18 +308,8 @@ private void runAnomalyDetectionJob( ) ); } catch (Exception e) { - indexAnomalyResultException( - jobParameter, - lockService, - lock, - detectionStartTime, - executionStartTime, - e, - true, - recorder, - detector - ); - log.error("Failed to execute AD job " + detectorId, e); + indexResultException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, true, recorder, detector); + log.error("Failed to execute AD job " + configId, e); } } @@ -390,17 +350,17 @@ private void runAnomalyDetectionJob( * @param executionStartTime detection end time * @param exception exception * @param recorder utility to record job execution result - * @param detector associated detector accessor + * @param config associated config accessor */ - protected void handleAdException( - AnomalyDetectorJob jobParameter, + protected void handleException( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, Exception exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config config ) { String detectorId = jobParameter.getName(); if (exception instanceof EndRunException) { @@ -409,7 +369,7 @@ protected void handleAdException( if (((EndRunException) exception).isEndNow()) { // Stop AD job if EndRunException shows we should end job now. 
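+ // endNow == true marks a fatal case (e.g., the config was deleted), so the job stops on the first occurrence; + // otherwise we only stop after more than maxRetryForEndRunException consecutive EndRunExceptions.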
log.info("JobRunner will stop AD job due to EndRunException for {}", detectorId); - stopAdJobForEndRunException( + stopJobForEndRunException( jobParameter, lockService, lock, @@ -417,10 +377,10 @@ protected void handleAdException( executionStartTime, (EndRunException) exception, recorder, - detector + config ); } else { - detectorEndRunExceptionCount.compute(detectorId, (k, v) -> { + endRunExceptionCount.compute(detectorId, (k, v) -> { if (v == null) { return 1; } else { @@ -429,14 +389,14 @@ protected void handleAdException( }); log.info("EndRunException happened for {}", detectorId); // if AD job failed consecutively due to EndRunException and failed times exceeds upper limit, will stop AD job - if (detectorEndRunExceptionCount.get(detectorId) > maxRetryForEndRunException) { + if (endRunExceptionCount.get(detectorId) > maxRetryForEndRunException) { log .info( "JobRunner will stop AD job due to EndRunException retry exceeds upper limit {} for {}", maxRetryForEndRunException, detectorId ); - stopAdJobForEndRunException( + stopJobForEndRunException( jobParameter, lockService, lock, @@ -444,11 +404,11 @@ protected void handleAdException( executionStartTime, (EndRunException) exception, recorder, - detector + config ); return; } - indexAnomalyResultException( + indexResultException( jobParameter, lockService, lock, @@ -457,17 +417,17 @@ protected void handleAdException( exception.getMessage(), true, recorder, - detector + config ); } } else { - detectorEndRunExceptionCount.remove(detectorId); + endRunExceptionCount.remove(detectorId); if (exception instanceof InternalFailure) { log.error("InternalFailure happened when executing anomaly result action for " + detectorId, exception); } else { log.error("Failed to execute anomaly result action for " + detectorId, exception); } - indexAnomalyResultException( + indexResultException( jobParameter, lockService, lock, @@ -476,30 +436,30 @@ protected void handleAdException( exception, true, recorder, - detector + config ); } } - private void stopAdJobForEndRunException( - AnomalyDetectorJob jobParameter, + private void stopJobForEndRunException( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, EndRunException exception, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config config ) { - String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); + String configId = jobParameter.getName(); + endRunExceptionCount.remove(configId); String errorPrefix = exception.isEndNow() - ? "Stopped detector: " - : "Stopped detector as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; + ? 
"Stopped analysis: " + : "Stopped analysis as job failed consecutively for more than " + this.maxRetryForEndRunException + " times: "; String error = errorPrefix + exception.getMessage(); - stopAdJob( - detectorId, - () -> indexAnomalyResultException( + stopJob( + configId, + () -> indexResultException( jobParameter, lockService, lock, @@ -507,14 +467,14 @@ private void stopAdJobForEndRunException( executionStartTime, error, true, - ADTaskState.STOPPED.name(), + TaskState.STOPPED.name(), recorder, - detector + config ) ); } - private void stopAdJob(String detectorId, ExecutorFunction function) { + private void stopJob(String detectorId, ExecutorFunction function) { GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(detectorId); ActionListener listener = ActionListener.wrap(response -> { if (response.isExists()) { @@ -524,9 +484,9 @@ private void stopAdJob(String detectorId, ExecutorFunction function) { .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + Job job = Job.parse(parser); if (job.isEnabled()) { - AnomalyDetectorJob newJob = new AnomalyDetectorJob( + Job newJob = new Job( job.getName(), job.getSchedule(), job.getWindowDelay(), @@ -536,7 +496,8 @@ private void stopAdJob(String detectorId, ExecutorFunction function) { Instant.now(), job.getLockDurationSeconds(), job.getUser(), - job.getCustomResultIndex() + job.getCustomResultIndex(), + job.getAnalysisType() ); IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) @@ -545,14 +506,14 @@ private void stopAdJob(String detectorId, ExecutorFunction function) { client.index(indexRequest, ActionListener.wrap(indexResponse -> { if (indexResponse != null && (indexResponse.getResult() == CREATED || indexResponse.getResult() == UPDATED)) { - log.info("AD Job was disabled by JobRunner for " + detectorId); + log.info("Job was disabled by JobRunner for " + detectorId); // function.execute(); } else { - log.warn("Failed to disable AD job for " + detectorId); + log.warn("Failed to disable job for " + detectorId); } - }, exception -> { log.error("JobRunner failed to update AD job as disabled for " + detectorId, exception); })); + }, exception -> { log.error("JobRunner failed to update job as disabled for " + detectorId, exception); })); } else { - log.info("AD Job was disabled for " + detectorId); + log.info("Job was disabled for " + detectorId); } } catch (IOException e) { log.error("JobRunner failed to stop detector job " + detectorId, e); @@ -565,22 +526,22 @@ private void stopAdJob(String detectorId, ExecutorFunction function) { client.get(getRequest, ActionListener.runAfter(listener, () -> function.execute())); } - private void indexAnomalyResult( - AnomalyDetectorJob jobParameter, + private void indexResult( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, - AnomalyResultResponse response, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ResultResponse response, + ExecuteResultResponseRecorderType recorder, + Config detector ) { String detectorId = jobParameter.getName(); - detectorEndRunExceptionCount.remove(detectorId); + endRunExceptionCount.remove(detectorId); try { - recorder.indexAnomalyResult(detectionStartTime, executionStartTime, response, detector); + 
recorder.indexResult(detectionStartTime, executionStartTime, response, detector); } catch (EndRunException e) { - handleAdException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); + handleException(jobParameter, lockService, lock, detectionStartTime, executionStartTime, e, recorder, detector); } catch (Exception e) { log.error("Failed to index anomaly result for " + detectorId, e); } finally { @@ -589,22 +550,22 @@ private void indexAnomalyResult( } - private void indexAnomalyResultException( - AnomalyDetectorJob jobParameter, + private void indexResultException( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, Exception exception, boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config detector ) { try { String errorMessage = exception instanceof TimeSeriesException ? exception.getMessage() : Throwables.getStackTraceAsString(exception); - indexAnomalyResultException( + indexResultException( jobParameter, lockService, lock, @@ -616,22 +577,22 @@ private void indexAnomalyResultException( detector ); } catch (Exception e) { - log.error("Failed to index anomaly result for " + jobParameter.getName(), e); + log.error("Failed to index result for " + jobParameter.getName(), e); } } - private void indexAnomalyResultException( - AnomalyDetectorJob jobParameter, + private void indexResultException( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, Instant executionStartTime, String errorMessage, boolean releaseLock, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config detector ) { - indexAnomalyResultException( + indexResultException( jobParameter, lockService, lock, @@ -645,8 +606,8 @@ private void indexAnomalyResultException( ); } - private void indexAnomalyResultException( - AnomalyDetectorJob jobParameter, + private void indexResultException( + Job jobParameter, LockService lockService, LockModel lock, Instant detectionStartTime, @@ -654,11 +615,11 @@ private void indexAnomalyResultException( String errorMessage, boolean releaseLock, String taskState, - ExecuteADResultResponseRecorder recorder, - AnomalyDetector detector + ExecuteResultResponseRecorderType recorder, + Config detector ) { try { - recorder.indexAnomalyResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); + recorder.indexResultException(detectionStartTime, executionStartTime, errorMessage, taskState, detector); } finally { if (releaseLock) { releaseLock(jobParameter, lockService, lock); @@ -666,15 +627,27 @@ private void indexAnomalyResultException( } } - private void releaseLock(AnomalyDetectorJob jobParameter, LockService lockService, LockModel lock) { + private void releaseLock(Job jobParameter, LockService lockService, LockModel lock) { lockService .release( lock, ActionListener .wrap( - released -> { log.info("Released lock for AD job {}", jobParameter.getName()); }, - exception -> { log.error("Failed to release lock for AD job: " + jobParameter.getName(), exception); } + released -> { log.info("Released lock for {} job {}", analysisType, jobParameter.getName()); }, + exception -> { + log + .error( + new ParameterizedMessage( + "Failed to release lock for [{}] job [{}]", + analysisType, + jobParameter.getName() + ), + exception + ); + } ) ); } + + protected abstract 
ResultRequest createResultRequest(String configID, long start, long end); } diff --git a/src/main/java/org/opensearch/timeseries/JobRunner.java b/src/main/java/org/opensearch/timeseries/JobRunner.java new file mode 100644 index 000000000..68a50ee4f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/JobRunner.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries; + +import org.opensearch.ad.ADJobProcessor; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.jobscheduler.spi.JobExecutionContext; +import org.opensearch.jobscheduler.spi.ScheduledJobParameter; +import org.opensearch.jobscheduler.spi.ScheduledJobRunner; +import org.opensearch.timeseries.model.Job; + +public class JobRunner implements ScheduledJobRunner { + private static JobRunner INSTANCE; + + public static JobRunner getJobRunnerInstance() { + if (INSTANCE != null) { + return INSTANCE; + } + synchronized (JobRunner.class) { + if (INSTANCE != null) { + return INSTANCE; + } + INSTANCE = new JobRunner(); + return INSTANCE; + } + } + + @Override + public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionContext context) { + if (!(scheduledJobParameter instanceof Job)) { + throw new IllegalArgumentException( + "Job parameter is not instance of Job, type: " + scheduledJobParameter.getClass().getCanonicalName() + ); + } + Job jobParameter = (Job) scheduledJobParameter; + switch (jobParameter.getAnalysisType()) { + case AD: + ADJobProcessor.getInstance().process(jobParameter, context); + break; + case FORECAST: + ForecastJobProcessor.getInstance().process(jobParameter, context); + break; + default: + throw new IllegalArgumentException("Analysis type is not supported, type: " + jobParameter.getAnalysisType()); + } + } +} diff --git a/src/main/java/org/opensearch/ad/MaintenanceState.java b/src/main/java/org/opensearch/timeseries/MaintenanceState.java similarity index 96% rename from src/main/java/org/opensearch/ad/MaintenanceState.java rename to src/main/java/org/opensearch/timeseries/MaintenanceState.java index 646715f7a..07bbb9546 100644 --- a/src/main/java/org/opensearch/ad/MaintenanceState.java +++ b/src/main/java/org/opensearch/timeseries/MaintenanceState.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.time.Duration; import java.util.Map; diff --git a/src/main/java/org/opensearch/ad/MemoryTracker.java b/src/main/java/org/opensearch/timeseries/MemoryTracker.java similarity index 89% rename from src/main/java/org/opensearch/ad/MemoryTracker.java rename to src/main/java/org/opensearch/timeseries/MemoryTracker.java index 1e40ef47a..a474ae21e 100644 --- a/src/main/java/org/opensearch/ad/MemoryTracker.java +++ b/src/main/java/org/opensearch/timeseries/MemoryTracker.java @@ -9,9 +9,9 @@ * GitHub history for details.
*/ -package org.opensearch.ad; +package org.opensearch.timeseries; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE; import java.util.EnumMap; import java.util.Locale; @@ -19,55 +19,48 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.monitor.jvm.JvmService; +import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import com.amazon.randomcutforest.RandomCutForest; import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; -/** - * Class to track AD memory usage. - * - */ public class MemoryTracker { private static final Logger LOG = LogManager.getLogger(MemoryTracker.class); public enum Origin { - SINGLE_ENTITY_DETECTOR, - HC_DETECTOR, + REAL_TIME_DETECTOR, HISTORICAL_SINGLE_ENTITY_DETECTOR, + REAL_TIME_FORECASTER } // memory tracker for total consumption of bytes - private long totalMemoryBytes; - private final Map totalMemoryBytesByOrigin; + protected long totalMemoryBytes; + protected final Map totalMemoryBytesByOrigin; // reserved for models. Cannot be deleted at will. - private long reservedMemoryBytes; - private final Map reservedMemoryBytesByOrigin; - private long heapSize; - private long heapLimitBytes; - private long desiredModelSize; + protected long reservedMemoryBytes; + protected final Map reservedMemoryBytesByOrigin; + protected long heapSize; + protected long heapLimitBytes; // we observe threshold model uses a fixed size array and the size is the same - private int thresholdModelBytes; - private ADCircuitBreakerService adCircuitBreakerService; + protected int thresholdModelBytes; + protected CircuitBreakerService timeSeriesCircuitBreakerService; /** * Constructor * * @param jvmService Service providing jvm info * @param modelMaxSizePercentage Percentage of heap for the max size of a model - * @param modelDesiredSizePercentage percentage of heap for the desired size of a model * @param clusterService Cluster service object - * @param adCircuitBreakerService Memory circuit breaker + * @param timeSeriesCircuitBreakerService Memory circuit breaker */ public MemoryTracker( JvmService jvmService, double modelMaxSizePercentage, - double modelDesiredSizePercentage, ClusterService clusterService, - ADCircuitBreakerService adCircuitBreakerService + CircuitBreakerService timeSeriesCircuitBreakerService ) { this.totalMemoryBytes = 0; this.totalMemoryBytesByOrigin = new EnumMap(Origin.class); @@ -75,40 +68,14 @@ public MemoryTracker( this.reservedMemoryBytesByOrigin = new EnumMap(Origin.class); this.heapSize = jvmService.info().getMem().getHeapMax().getBytes(); this.heapLimitBytes = (long) (heapSize * modelMaxSizePercentage); - this.desiredModelSize = (long) (heapSize * modelDesiredSizePercentage); if (clusterService != null) { clusterService .getClusterSettings() - .addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); + .addSettingsUpdateConsumer(AD_MODEL_MAX_SIZE_PERCENTAGE, it -> this.heapLimitBytes = (long) (heapSize * it)); } this.thresholdModelBytes = 180_000; - this.adCircuitBreakerService = adCircuitBreakerService; - } - - /** - * This function derives from the old code: https://tinyurl.com/2eaabja6 - * - * @param detectorId 
Detector Id - * @param trcf Thresholded random cut forest model - * @return true if there is enough memory; otherwise throw LimitExceededException. - */ - public synchronized boolean isHostingAllowed(String detectorId, ThresholdedRandomCutForest trcf) { - long requiredBytes = estimateTRCFModelSize(trcf); - if (canAllocateReserved(requiredBytes)) { - return true; - } else { - throw new LimitExceededException( - detectorId, - String - .format( - Locale.ROOT, - "Exceeded memory limit. New size is %d bytes and max limit is %d bytes", - reservedMemoryBytes + requiredBytes, - heapLimitBytes - ) - ); - } + this.timeSeriesCircuitBreakerService = timeSeriesCircuitBreakerService; } /** @@ -117,7 +84,7 @@ public synchronized boolean isHostingAllowed(String detectorId, ThresholdedRando * true when circuit breaker is closed and there is enough reserved memory. */ public synchronized boolean canAllocateReserved(long requiredBytes) { - return (false == adCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); + return (false == timeSeriesCircuitBreakerService.isOpen() && reservedMemoryBytes + requiredBytes <= heapLimitBytes); } /** @@ -126,7 +93,7 @@ public synchronized boolean canAllocateReserved(long requiredBytes) { * true when circuit breaker is closed and there is enough overall memory. */ public synchronized boolean canAllocate(long bytes) { - return false == adCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; + return false == timeSeriesCircuitBreakerService.isOpen() && totalMemoryBytes + bytes <= heapLimitBytes; } public synchronized void consumeMemory(long memoryToConsume, boolean reserved, Origin origin) { @@ -159,23 +126,6 @@ private void adjustOriginMemoryRelease(long memoryToConsume, Origin origin, Map< } } - /** - * Gets the estimated size of an entity's model. - * - * @param trcf ThresholdedRandomCutForest object - * @return estimated model size in bytes - */ - public long estimateTRCFModelSize(ThresholdedRandomCutForest trcf) { - RandomCutForest forest = trcf.getForest(); - return estimateTRCFModelSize( - forest.getDimensions(), - forest.getNumberOfTrees(), - forest.getBoundingBoxCacheFraction(), - forest.getShingleSize(), - forest.isInternalShinglingEnabled() - ); - } - /** * Gets the estimated size of an entity's model. * @@ -306,14 +256,6 @@ public long getHeapLimit() { return heapLimitBytes; } - /** - * - * @return Desired model partition size in bytes - */ - public long getDesiredModelSize() { - return desiredModelSize; - } - public long getTotalMemoryBytes() { return totalMemoryBytes; } @@ -360,4 +302,46 @@ public synchronized boolean syncMemoryState(Origin origin, long totalBytes, long public int getThresholdModelBytes() { return thresholdModelBytes; } + + /** + * This function derives from the old code: https://tinyurl.com/2eaabja6 + * + * @param configId Config Id + * @param trcf Thresholded random cut forest model + * @return true if there is enough memory; otherwise throw LimitExceededException. + */ + public synchronized boolean isHostingAllowed(String configId, ThresholdedRandomCutForest trcf) { + long requiredBytes = estimateTRCFModelSize(trcf); + if (canAllocateReserved(requiredBytes)) { + return true; + } else { + throw new LimitExceededException( + configId, + String + .format( + Locale.ROOT, + "Exceeded memory limit. New size is %d bytes and max limit is %d bytes", + reservedMemoryBytes + requiredBytes, + heapLimitBytes + ) + ); + } + } + + /** + * Gets the estimated size of an entity's model. 
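+ * The estimate is derived from the forest's dimensions, number of trees, bounding-box cache fraction, + * shingle size, and internal-shingling flag, via the overload this method delegates to.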
+ * + * @param trcf ThresholdedRandomCutForest object + * @return estimated model size in bytes + */ + public long estimateTRCFModelSize(ThresholdedRandomCutForest trcf) { + RandomCutForest forest = trcf.getForest(); + return estimateTRCFModelSize( + forest.getDimensions(), + forest.getNumberOfTrees(), + forest.getBoundingBoxCacheFraction(), + forest.getShingleSize(), + forest.isInternalShinglingEnabled() + ); + } } diff --git a/src/main/java/org/opensearch/ad/NodeState.java b/src/main/java/org/opensearch/timeseries/NodeState.java similarity index 57% rename from src/main/java/org/opensearch/ad/NodeState.java rename to src/main/java/org/opensearch/timeseries/NodeState.java index 9c4693cbd..8537d0b64 100644 --- a/src/main/java/org/opensearch/ad/NodeState.java +++ b/src/main/java/org/opensearch/timeseries/NodeState.java @@ -9,198 +9,180 @@ * GitHub history for details. */ -package org.opensearch.ad; +package org.opensearch.timeseries; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.Optional; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; /** * Storing intermediate state during the execution of transport action * */ public class NodeState implements ExpiringState { - private String detectorId; - // detector definition - private AnomalyDetector detectorDef; - // number of partitions - private int partitonNumber; + private String configId; + // config definition + private Config configDef; // last access time private Instant lastAccessTime; - // last detection error recorded in result index. Used by DetectorStateHandler - // to check if the error for a detector has changed or not. If changed, trigger indexing. - private Optional lastDetectionError; // last error. 
    private Optional<Exception> exception;
-    // flag indicating whether checkpoint for the detector exists
-    private boolean checkPointExists;
     // clock to get current time
     private final Clock clock;
+    // config job
+    private Job configJob;
+
+    // AD only states
+    // number of partitions
+    private int partitonNumber;
+
+    // flag indicating whether checkpoint for the detector exists
+    private boolean checkPointExists;
+
     // cold start running flag to prevent concurrent cold start
     private boolean coldStartRunning;
-    // detector job
-    private AnomalyDetectorJob detectorJob;

-    public NodeState(String detectorId, Clock clock) {
-        this.detectorId = detectorId;
-        this.detectorDef = null;
-        this.partitonNumber = -1;
+    public NodeState(String configId, Clock clock) {
+        this.configId = configId;
+        this.configDef = null;
         this.lastAccessTime = clock.instant();
-        this.lastDetectionError = Optional.empty();
         this.exception = Optional.empty();
-        this.checkPointExists = false;
         this.clock = clock;
+        this.partitonNumber = -1;
+        this.checkPointExists = false;
         this.coldStartRunning = false;
-        this.detectorJob = null;
+        this.configJob = null;
     }

-    public String getId() {
-        return detectorId;
+    public String getConfigId() {
+        return configId;
     }

     /**
      *
      * @return Detector configuration object
      */
-    public AnomalyDetector getDetectorDef() {
+    public Config getConfigDef() {
         refreshLastUpdateTime();
-        return detectorDef;
+        return configDef;
     }

     /**
      *
-     * @param detectorDef Detector configuration object
+     * @param configDef Analysis configuration object
      */
-    public void setDetectorDef(AnomalyDetector detectorDef) {
-        this.detectorDef = detectorDef;
+    public void setConfigDef(Config configDef) {
+        this.configDef = configDef;
         refreshLastUpdateTime();
     }

     /**
      *
-     * @return RCF partition number of the detector
+     * @return last exception if any
      */
-    public int getPartitonNumber() {
+    public Optional<Exception> getException() {
         refreshLastUpdateTime();
-        return partitonNumber;
+        return exception;
     }

     /**
      *
-     * @param partitonNumber RCF partition number
+     * @param exception exception to record
      */
-    public void setPartitonNumber(int partitonNumber) {
-        this.partitonNumber = partitonNumber;
+    public void setException(Exception exception) {
+        this.exception = Optional.ofNullable(exception);
         refreshLastUpdateTime();
     }

     /**
-     * Used to indicate whether cold start succeeds or not
-     * @return whether checkpoint of models exists or not.
+     * refresh last access time.
      */
-    public boolean doesCheckpointExists() {
-        refreshLastUpdateTime();
-        return checkPointExists;
+    protected void refreshLastUpdateTime() {
+        lastAccessTime = clock.instant();
     }

     /**
-     *
-     * @param checkpointExists mark whether checkpoint of models exists or not.
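// ---------------------------------------------------------------------------
// Editor's note: an illustrative, JDK-only sketch of the expiry contract
// NodeState follows -- every getter and setter calls refreshLastUpdateTime(),
// and expired(stateTtl) compares lastAccessTime + ttl against the clock.
// Class and field names here are hypothetical, not the plugin's API.
// ---------------------------------------------------------------------------
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;

class ExpiringStateSketch {
    private final Clock clock;
    private Instant lastAccessTime;
    private String cachedValue;

    ExpiringStateSketch(Clock clock) {
        this.clock = clock;
        this.lastAccessTime = clock.instant();
    }

    String getValue() {
        lastAccessTime = clock.instant(); // every access keeps the state alive
        return cachedValue;
    }

    void setValue(String value) {
        this.cachedValue = value;
        lastAccessTime = clock.instant();
    }

    boolean expired(Duration stateTtl) {
        // same shape as ExpiringState.expired(lastAccessTime, stateTtl, now)
        return lastAccessTime.plus(stateTtl).isBefore(clock.instant());
    }
}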
+     * @param stateTtl time to live for the state
+     * @return whether the transport state is expired
      */
-    public void setCheckpointExists(boolean checkpointExists) {
-        refreshLastUpdateTime();
-        this.checkPointExists = checkpointExists;
-    };
+    @Override
+    public boolean expired(Duration stateTtl) {
+        return expired(lastAccessTime, stateTtl, clock.instant());
+    }

     /**
-     *
-     * @return last model inference error
-     */
-    public Optional<String> getLastDetectionError() {
+     *
+     * @return RCF partition number of the detector
+     */
+    public int getPartitonNumber() {
         refreshLastUpdateTime();
-        return lastDetectionError;
+        return partitonNumber;
     }

     /**
-     *
-     * @param lastError last model inference error
-     */
-    public void setLastDetectionError(String lastError) {
-        this.lastDetectionError = Optional.ofNullable(lastError);
+     *
+     * @param partitonNumber RCF partition number
+     */
+    public void setPartitonNumber(int partitonNumber) {
+        this.partitonNumber = partitonNumber;
         refreshLastUpdateTime();
     }

     /**
-     *
-     * @return last exception if any
-     */
-    public Optional<Exception> getException() {
+     * Used to indicate whether cold start succeeds or not
+     * @return whether checkpoint of models exists or not.
+     */
+    public boolean doesCheckpointExists() {
         refreshLastUpdateTime();
-        return exception;
+        return checkPointExists;
     }

     /**
-     *
-     * @param exception exception to record
-     */
-    public void setException(Exception exception) {
-        this.exception = Optional.ofNullable(exception);
+     *
+     * @param checkpointExists mark whether checkpoint of models exists or not.
+     */
+    public void setCheckpointExists(boolean checkpointExists) {
         refreshLastUpdateTime();
-    }
+        this.checkPointExists = checkpointExists;
+    };

     /**
-    * Used to prevent concurrent cold start
-    * @return whether cold start is running or not
-    */
+     * Used to prevent concurrent cold start
+     * @return whether cold start is running or not
+     */
     public boolean isColdStartRunning() {
         refreshLastUpdateTime();
         return coldStartRunning;
     }

     /**
-    *
-    * @param coldStartRunning whether cold start is running or not
-    */
+     *
+     * @param coldStartRunning whether cold start is running or not
+     */
     public void setColdStartRunning(boolean coldStartRunning) {
         this.coldStartRunning = coldStartRunning;
         refreshLastUpdateTime();
     }

     /**
-    *
-    * @return Detector configuration object
-    */
-    public AnomalyDetectorJob getDetectorJob() {
+     *
+     * @return Job configuration object
+     */
+    public Job getJob() {
         refreshLastUpdateTime();
-        return detectorJob;
+        return configJob;
     }

     /**
-    *
-    * @param detectorJob Detector job
-    */
-    public void setDetectorJob(AnomalyDetectorJob detectorJob) {
-        this.detectorJob = detectorJob;
+     *
+     * @param job analysis job
+     */
+    public void setJob(Job job) {
+        this.configJob = job;
         refreshLastUpdateTime();
     }
-
-    /**
-     * refresh last access time.
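// ---------------------------------------------------------------------------
// Editor's note: the coldStartRunning flag above is a simple mutual-exclusion
// latch around cold start. A hedged sketch of the acquire/release shape -- the
// real code hands the caller a Releasable from markColdStartRunning(), mirrored
// here with AutoCloseable. Hypothetical names, JDK types only; the check-then-
// act is advisory, as in the original, not a strict lock.
// ---------------------------------------------------------------------------
import java.util.concurrent.atomic.AtomicBoolean;

class ColdStartGuardSketch {
    private final AtomicBoolean coldStartRunning = new AtomicBoolean(false);

    boolean isColdStartRunning() {
        return coldStartRunning.get();
    }

    // Mirrors markColdStartRunning(): raise the flag, return a token that lowers it.
    AutoCloseable markColdStartRunning() {
        coldStartRunning.set(true);
        return () -> coldStartRunning.set(false);
    }

    void maybeColdStart(Runnable coldStart) {
        if (coldStartRunning.get()) {
            return; // another thread is already cold starting this config
        }
        try (AutoCloseable release = markColdStartRunning()) {
            coldStart.run();
        } catch (Exception e) { // close() is declared to throw; ours never does
            throw new RuntimeException(e);
        }
    }
}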
- */ - private void refreshLastUpdateTime() { - lastAccessTime = clock.instant(); - } - - /** - * @param stateTtl time to leave for the state - * @return whether the transport state is expired - */ - @Override - public boolean expired(Duration stateTtl) { - return expired(lastAccessTime, stateTtl, clock.instant()); - } } diff --git a/src/main/java/org/opensearch/ad/NodeStateManager.java b/src/main/java/org/opensearch/timeseries/NodeStateManager.java similarity index 59% rename from src/main/java/org/opensearch/ad/NodeStateManager.java rename to src/main/java/org/opensearch/timeseries/NodeStateManager.java index 7e3d708c2..d18b13465 100644 --- a/src/main/java/org/opensearch/ad/NodeStateManager.java +++ b/src/main/java/org/opensearch/timeseries/NodeStateManager.java @@ -1,20 +1,8 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. - */ - -package org.opensearch.ad; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +package org.opensearch.timeseries; + import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.util.HashMap; @@ -22,50 +10,56 @@ import java.util.Map; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Strings; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.SingleStreamModelIdMapper; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.transport.BackPressureRouting; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; - -/** - * NodeStateManager is used to manage states shared by transport and ml components - * like AnomalyDetector object - * - */ -public class NodeStateManager implements MaintenanceState, CleanState { +import org.opensearch.timeseries.function.BiCheckedFunction; +import 
org.opensearch.timeseries.ml.SingleStreamModelIdMapper; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.BackPressureRouting; +import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class NodeStateManager implements MaintenanceState, CleanState, ExceptionRecorder { private static final Logger LOG = LogManager.getLogger(NodeStateManager.class); + public static final String NO_ERROR = "no_error"; - private ConcurrentHashMap states; - private Client client; - private NamedXContentRegistry xContentRegistry; - private ClientUtil clientUtil; + + protected ConcurrentHashMap states; + protected Client client; + protected NamedXContentRegistry xContentRegistry; + protected ClientUtil clientUtil; + protected final Clock clock; + protected final Duration stateTtl; // map from detector id to the map of ES node id to the node's backpressureMuter private Map> backpressureMuter; - private final Clock clock; - private final Duration stateTtl; private int maxRetryForUnresponsiveNode; private TimeValue mutePeriod; @@ -87,17 +81,20 @@ public NodeStateManager( ClientUtil clientUtil, Clock clock, Duration stateTtl, - ClusterService clusterService + ClusterService clusterService, + Setting maxRetryForUnresponsiveNodeSetting, + Setting backoffMinutesSetting ) { this.states = new ConcurrentHashMap<>(); this.client = client; this.xContentRegistry = xContentRegistry; this.clientUtil = clientUtil; - this.backpressureMuter = new ConcurrentHashMap<>(); this.clock = clock; this.stateTtl = stateTtl; - this.maxRetryForUnresponsiveNode = MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_RETRY_FOR_UNRESPONSIVE_NODE, it -> { + this.backpressureMuter = new ConcurrentHashMap<>(); + + this.maxRetryForUnresponsiveNode = maxRetryForUnresponsiveNodeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxRetryForUnresponsiveNodeSetting, it -> { this.maxRetryForUnresponsiveNode = it; Iterator> iter = backpressureMuter.values().iterator(); while (iter.hasNext()) { @@ -105,8 +102,8 @@ public NodeStateManager( entry.values().forEach(v -> v.setMaxRetryForUnresponsiveNode(it)); } }); - this.mutePeriod = BACKOFF_MINUTES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(BACKOFF_MINUTES, it -> { + this.mutePeriod = backoffMinutesSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(backoffMinutesSetting, it -> { this.mutePeriod = it; Iterator> iter = backpressureMuter.values().iterator(); while (iter.hasNext()) { @@ -114,117 +111,37 @@ public NodeStateManager( entry.values().forEach(v -> v.setMutePeriod(it)); } }); - } - - /** - * Get Detector config object if present - * @param adID detector Id - * @return the Detecor config object or empty Optional - */ - public Optional getAnomalyDetectorIfPresent(String adID) { - NodeState state = states.get(adID); - return Optional.ofNullable(state).map(NodeState::getDetectorDef); - } - - public void getAnomalyDetector(String adID, ActionListener> listener) { - NodeState state = states.get(adID); - if (state != null && state.getDetectorDef() != null) { - listener.onResponse(Optional.of(state.getDetectorDef())); - } else { - GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, adID); - clientUtil.asyncRequest(request, client::get, 
onGetDetectorResponse(adID, listener)); - } - } - - private ActionListener onGetDetectorResponse(String adID, ActionListener> listener) { - return ActionListener.wrap(response -> { - if (response == null || !response.isExists()) { - listener.onResponse(Optional.empty()); - return; - } - - String xc = response.getSourceAsString(); - LOG.debug("Fetched anomaly detector: {}", xc); - - try ( - XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) - ) { - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser, response.getId()); - // end execution if all features are disabled - if (detector.getEnabledFeatureIds().isEmpty()) { - listener.onFailure(new EndRunException(adID, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG, true).countedInStats(false)); - return; - } - NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); - state.setDetectorDef(detector); - listener.onResponse(Optional.of(detector)); - } catch (Exception t) { - LOG.error("Fail to parse detector {}", adID); - LOG.error("Stack trace:", t); - listener.onResponse(Optional.empty()); - } - }, listener::onFailure); } /** - * Get a detector's checkpoint and save a flag if we find any so that next time we don't need to do it again - * @param adID the detector's ID - * @param listener listener to handle get request + * Clean states if it is older than our stateTtl. transportState has to be a + * ConcurrentHashMap otherwise we will have + * java.util.ConcurrentModificationException. + * */ - public void getDetectorCheckpoint(String adID, ActionListener listener) { - NodeState state = states.get(adID); - if (state != null && state.doesCheckpointExists()) { - listener.onResponse(Boolean.TRUE); - return; - } - - GetRequest request = new GetRequest(ADCommonName.CHECKPOINT_INDEX_NAME, SingleStreamModelIdMapper.getRcfModelId(adID, 0)); - - clientUtil.asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener)); - } - - private ActionListener onGetCheckpointResponse(String adID, ActionListener listener) { - return ActionListener.wrap(response -> { - if (response == null || !response.isExists()) { - listener.onResponse(Boolean.FALSE); - } else { - NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); - state.setCheckpointExists(true); - listener.onResponse(Boolean.TRUE); - } - }, listener::onFailure); + @Override + public void maintenance() { + maintenance(states, stateTtl); } /** * Used in delete workflow * - * @param detectorId detector ID + * @param configId detector ID */ @Override - public void clear(String detectorId) { - Map routingMap = backpressureMuter.get(detectorId); + public void clear(String configId) { + Map routingMap = backpressureMuter.get(configId); if (routingMap != null) { routingMap.clear(); - backpressureMuter.remove(detectorId); + backpressureMuter.remove(configId); } - states.remove(detectorId); + states.remove(configId); } - /** - * Clean states if it is older than our stateTtl. transportState has to be a - * ConcurrentHashMap otherwise we will have - * java.util.ConcurrentModificationException. 
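// ---------------------------------------------------------------------------
// Editor's note: the "ConcurrentHashMap otherwise ... ConcurrentModificationException"
// comment above is load-bearing: the TTL sweep removes entries while traversing
// the map, which a plain HashMap iterator would reject. A minimal sketch of
// such a sweep (hypothetical names, JDK types only):
// ---------------------------------------------------------------------------
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ConcurrentHashMap;

class MaintenanceSketch {
    private final ConcurrentHashMap<String, Instant> lastAccess = new ConcurrentHashMap<>();

    void maintenance(Duration stateTtl, Instant now) {
        // removal during iteration is safe on ConcurrentHashMap views
        lastAccess.entrySet().removeIf(e -> e.getValue().plus(stateTtl).isBefore(now));
    }
}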
- * - */ - @Override - public void maintenance() { - maintenance(states, stateTtl); - } - - public boolean isMuted(String nodeId, String detectorId) { - Map routingMap = backpressureMuter.get(detectorId); + public boolean isMuted(String nodeId, String configId) { + Map routingMap = backpressureMuter.get(configId); if (routingMap == null || routingMap.isEmpty()) { return false; } @@ -235,68 +152,140 @@ public boolean isMuted(String nodeId, String detectorId) { /** * When we have a unsuccessful call with a node, increment the backpressure counter. * @param nodeId an ES node's ID - * @param detectorId Detector ID + * @param configId config ID */ - public void addPressure(String nodeId, String detectorId) { + public void addPressure(String nodeId, String configId) { Map routingMap = backpressureMuter - .computeIfAbsent(detectorId, k -> new HashMap()); + .computeIfAbsent(configId, k -> new HashMap()); routingMap.computeIfAbsent(nodeId, k -> new BackPressureRouting(k, clock, maxRetryForUnresponsiveNode, mutePeriod)).addPressure(); } /** * When we have a successful call with a node, clear the backpressure counter. * @param nodeId an ES node's ID - * @param detectorId Detector ID + * @param configId config ID */ - public void resetBackpressureCounter(String nodeId, String detectorId) { - Map routingMap = backpressureMuter.get(detectorId); + public void resetBackpressureCounter(String nodeId, String configId) { + Map routingMap = backpressureMuter.get(configId); if (routingMap == null || routingMap.isEmpty()) { - backpressureMuter.remove(detectorId); + backpressureMuter.remove(configId); return; } routingMap.remove(nodeId); } /** - * Check if there is running query on given detector - * @param detector Anomaly Detector - * @return true if given detector has a running query else false + * Get config and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param analysisType analysis type + * @param function consumer function. + * @param listener action listener. Only meant to return failure. 
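// ---------------------------------------------------------------------------
// Editor's note: a compact, JDK-only sketch of the backpressure muting shown
// above (addPressure / isMuted / resetBackpressureCounter): after
// maxRetryForUnresponsiveNode consecutive failures against a node, the
// (config, node) pair is muted for mutePeriod. The real class keeps a
// BackPressureRouting per node and rewires both limits on dynamic setting
// updates; everything named here is a hypothetical stand-in.
// ---------------------------------------------------------------------------
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.ConcurrentHashMap;

class BackpressureMuteSketch {
    private final Clock clock;
    private final int maxRetryForUnresponsiveNode;
    private final Duration mutePeriod;
    // key: configId + ":" + nodeId -> consecutive failures / mute start time
    private final ConcurrentHashMap<String, Integer> failureCount = new ConcurrentHashMap<>();
    private final ConcurrentHashMap<String, Instant> muteStart = new ConcurrentHashMap<>();

    BackpressureMuteSketch(Clock clock, int maxRetryForUnresponsiveNode, Duration mutePeriod) {
        this.clock = clock;
        this.maxRetryForUnresponsiveNode = maxRetryForUnresponsiveNode;
        this.mutePeriod = mutePeriod;
    }

    void addPressure(String configId, String nodeId) {
        String key = configId + ":" + nodeId;
        if (failureCount.merge(key, 1, Integer::sum) >= maxRetryForUnresponsiveNode) {
            muteStart.putIfAbsent(key, clock.instant()); // open the mute window once
        }
    }

    boolean isMuted(String configId, String nodeId) {
        Instant start = muteStart.get(configId + ":" + nodeId);
        return start != null && start.plus(mutePeriod).isAfter(clock.instant());
    }

    void resetBackpressureCounter(String configId, String nodeId) {
        String key = configId + ":" + nodeId;
        failureCount.remove(key);
        muteStart.remove(key);
    }
}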
+ * @param action listener response type */ - public boolean hasRunningQuery(AnomalyDetector detector) { - return clientUtil.hasRunningQuery(detector); + public void getConfig( + String configId, + AnalysisType analysisType, + Consumer> function, + ActionListener listener + ) { + GetRequest getRequest = new GetRequest(CommonName.CONFIG_INDEX, configId); + client.get(getRequest, ActionListener.wrap(response -> { + if (!response.isExists()) { + function.accept(Optional.empty()); + return; + } + try ( + XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) + ) { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config config = null; + if (analysisType == AnalysisType.AD) { + config = AnomalyDetector.parse(parser, response.getId(), response.getVersion()); + } else if (analysisType == AnalysisType.FORECAST) { + config = Forecaster.parse(parser, response.getId(), response.getVersion()); + } else { + throw new UnsupportedOperationException("This method is not supported"); + } + + function.accept(Optional.of(config)); + } catch (Exception e) { + String message = "Failed to parse config " + configId; + LOG.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, exception -> { + LOG.error("Failed to get config " + configId, exception); + listener.onFailure(exception); + })); } - /** - * Get last error of a detector - * @param adID detector id - * @return last error for the detector - */ - public String getLastDetectionError(String adID) { - return Optional.ofNullable(states.get(adID)).flatMap(state -> state.getLastDetectionError()).orElse(NO_ERROR); + public void getConfig(String configID, AnalysisType context, ActionListener> listener) { + NodeState state = states.get(configID); + if (state != null && state.getConfigDef() != null) { + listener.onResponse(Optional.of(state.getConfigDef())); + } else { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, configID); + BiCheckedFunction configParser = context == AnalysisType.AD + ? 
AnomalyDetector::parse + : Forecaster::parse; + clientUtil.asyncRequest(request, client::get, onGetConfigResponse(configID, configParser, listener)); + } } - /** - * Set last detection error of a detector - * @param adID detector id - * @param error error, can be null - */ - public void setLastDetectionError(String adID, String error) { - NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); - state.setLastDetectionError(error); + private ActionListener onGetConfigResponse( + String configID, + BiCheckedFunction configParser, + ActionListener> listener + ) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Optional.empty()); + return; + } + + String xc = response.getSourceAsString(); + LOG.debug("Fetched config: {}", xc); + + try ( + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, xc) + ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config config = configParser.apply(parser, response.getId()); + + // end execution if all features are disabled + if (config.getEnabledFeatureIds().isEmpty()) { + listener + .onFailure(new EndRunException(configID, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG, true).countedInStats(false)); + return; + } + + NodeState state = states.computeIfAbsent(configID, configId -> new NodeState(configId, clock)); + state.setConfigDef(config); + + listener.onResponse(Optional.of(config)); + } catch (Exception t) { + LOG.error("Fail to parse config {}", configID); + LOG.error("Stack trace:", t); + listener.onResponse(Optional.empty()); + } + }, listener::onFailure); } /** - * Get a detector's exception. The method has side effect. + * Get the exception of an analysis. The method has side effect. * We reset error after calling the method because - * 1) We record a detector's exception in each interval. There is no need - * to record it twice. + * 1) We record the exception of an analysis in each interval. + * There is no need to record it twice. * 2) EndRunExceptions can stop job running. We only want to send the same * signal once for each exception. - * @param adID detector id - * @return the detector's exception + * @param configID config id + * @return the config's exception */ - public Optional fetchExceptionAndClear(String adID) { - NodeState state = states.get(adID); + @Override + public Optional fetchExceptionAndClear(String configID) { + NodeState state = states.get(configID); if (state == null) { return Optional.empty(); } @@ -307,26 +296,27 @@ public Optional fetchExceptionAndClear(String adID) { } /** - * For single-stream detector, we have one exception per interval. When + * For single-stream analysis, we have one exception per interval. When * an interval starts, it fetches and clears the exception. - * For HCAD, there can be one exception per entity. To not bloat memory + * For HC analysis, there can be one exception per entity. To not bloat memory * with exceptions, we will keep only one exception. An exception has 3 purposes: - * 1) stop detector if nothing else works; + * 1) stop analysis if nothing else works; * 2) increment error stats to ticket about high-error domain * 3) debugging. * - * For HCAD, we record all entities' exceptions in anomaly results. So 3) + * For HC analysis, we record all entities' exceptions in result index. So 3) * is covered. As long as we keep one exception among all exceptions, 2) * is covered. 
So the only thing we have to pay attention is to keep EndRunException. * When overriding an exception, EndRunException has priority. - * @param detectorId Detector Id + * @param configId Detector Id * @param e Exception to set */ - public void setException(String detectorId, Exception e) { - if (e == null || Strings.isEmpty(detectorId)) { + @Override + public void setException(String configId, Exception e) { + if (e == null || Strings.isEmpty(configId)) { return; } - NodeState state = states.computeIfAbsent(detectorId, d -> new NodeState(detectorId, clock)); + NodeState state = states.computeIfAbsent(configId, d -> new NodeState(configId, clock)); Optional exception = state.getException(); if (exception.isPresent()) { Exception higherPriorityException = ExceptionUtil.selectHigherPriorityException(e, exception.get()); @@ -338,6 +328,35 @@ public void setException(String detectorId, Exception e) { state.setException(e); } + /** + * Get a detector's checkpoint and save a flag if we find any so that next time we don't need to do it again + * @param adID the detector's ID + * @param listener listener to handle get request + */ + public void getDetectorCheckpoint(String adID, ActionListener listener) { + NodeState state = states.get(adID); + if (state != null && state.doesCheckpointExists()) { + listener.onResponse(Boolean.TRUE); + return; + } + + GetRequest request = new GetRequest(ADCommonName.CHECKPOINT_INDEX_NAME, SingleStreamModelIdMapper.getRcfModelId(adID, 0)); + + clientUtil.asyncRequest(request, client::get, onGetCheckpointResponse(adID, listener)); + } + + private ActionListener onGetCheckpointResponse(String adID, ActionListener listener) { + return ActionListener.wrap(response -> { + if (response == null || !response.isExists()) { + listener.onResponse(Boolean.FALSE); + } else { + NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); + state.setCheckpointExists(true); + listener.onResponse(Boolean.TRUE); + } + }, listener::onFailure); + } + /** * Whether last cold start for the detector is running * @param adID detector ID @@ -368,17 +387,17 @@ public Releasable markColdStartRunning(String adID) { }; } - public void getAnomalyDetectorJob(String adID, ActionListener> listener) { - NodeState state = states.get(adID); - if (state != null && state.getDetectorJob() != null) { - listener.onResponse(Optional.of(state.getDetectorJob())); + public void getJob(String configID, ActionListener> listener) { + NodeState state = states.get(configID); + if (state != null && state.getJob() != null) { + listener.onResponse(Optional.of(state.getJob())); } else { - GetRequest request = new GetRequest(CommonName.JOB_INDEX, adID); - clientUtil.asyncRequest(request, client::get, onGetDetectorJobResponse(adID, listener)); + GetRequest request = new GetRequest(CommonName.JOB_INDEX, configID); + clientUtil.asyncRequest(request, client::get, onGetJobResponse(configID, listener)); } } - private ActionListener onGetDetectorJobResponse(String adID, ActionListener> listener) { + private ActionListener onGetJobResponse(String configID, ActionListener> listener) { return ActionListener.wrap(response -> { if (response == null || !response.isExists()) { listener.onResponse(Optional.empty()); @@ -386,7 +405,7 @@ private ActionListener onGetDetectorJobResponse(String adID, Action } String xc = response.getSourceAsString(); - LOG.debug("Fetched anomaly detector: {}", xc); + LOG.debug("Fetched config: {}", xc); try ( XContentParser parser = XContentType.JSON @@ -394,13 +413,13 @@ private 
ActionListener onGetDetectorJobResponse(String adID, Action .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, response.getSourceAsString()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); - NodeState state = states.computeIfAbsent(adID, id -> new NodeState(id, clock)); - state.setDetectorJob(job); + Job job = Job.parse(parser); + NodeState state = states.computeIfAbsent(configID, id -> new NodeState(id, clock)); + state.setJob(job); listener.onResponse(Optional.of(job)); } catch (Exception t) { - LOG.error(new ParameterizedMessage("Fail to parse job {}", adID), t); + LOG.error(new ParameterizedMessage("Fail to parse job {}", configID), t); listener.onResponse(Optional.empty()); } }, listener::onFailure); diff --git a/src/main/java/org/opensearch/timeseries/ProfileUtil.java b/src/main/java/org/opensearch/timeseries/ProfileUtil.java new file mode 100644 index 000000000..82c2338bb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ProfileUtil.java @@ -0,0 +1,109 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.client.Client; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Config; + +public class ProfileUtil { + /** + * Create search request to check if we have at least 1 anomaly score larger than 0 after AD job enabled time. + * Note this function is only meant to check for status of real time analysis. + * + * @param detectorId detector id + * @param enabledTime the time when AD job is enabled in milliseconds + * @return the search request + */ + private static SearchRequest createADRealtimeInittedEverRequest(String detectorId, long enabledTime, String resultIndex) { + BoolQueryBuilder filterQuery = new BoolQueryBuilder(); + filterQuery.filter(QueryBuilders.termQuery(AnomalyResult.DETECTOR_ID_FIELD, detectorId)); + filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime)); + filterQuery.filter(QueryBuilders.rangeQuery(AnomalyResult.ANOMALY_SCORE_FIELD).gt(0)); + // Historical analysis result also stored in result index, which has non-null task_id. 
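// ---------------------------------------------------------------------------
// Editor's note: the surrounding comments carry the key trick of this request:
// realtime and historical runs share one result index, and realtime documents
// are exactly those without a task_id. A hedged, standalone sketch of the same
// filter using the query builders this file already imports; the literal field
// names below are illustrative placeholders, not the plugin's constants.
// ---------------------------------------------------------------------------
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.builder.SearchSourceBuilder;

class RealtimeResultQuerySketch {
    static SearchSourceBuilder realtimeInitQuery(String configId, long enabledTimeMillis) {
        BoolQueryBuilder filter = QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery("detector_id", configId))                       // this analysis only
            .filter(QueryBuilders.rangeQuery("execution_end_time").gte(enabledTimeMillis))  // after the job was enabled
            .filter(QueryBuilders.rangeQuery("anomaly_score").gt(0))                        // at least one real score
            .mustNot(QueryBuilders.existsQuery("task_id"));                                 // realtime == no task_id
        // size(1): we only need to know whether any such document exists
        return new SearchSourceBuilder().query(filter).size(1);
    }
}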
+        // For realtime detection result, we should filter task_id == null
+        ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD);
+        filterQuery.mustNot(taskIdExistsFilter);
+
+        SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1);
+
+        SearchRequest request = new SearchRequest(ADCommonName.ANOMALY_RESULT_INDEX_ALIAS);
+        request.source(source);
+        if (resultIndex != null) {
+            request.indices(resultIndex);
+        }
+        return request;
+    }
+
+    /**
+     * Create search request to check if we have at least 1 forecast result after the forecast job enabled time.
+     * Note this function is only meant to check for status of real time analysis.
+     *
+     * @param forecasterId forecaster id
+     * @param enabledTime the time when forecast job is enabled in milliseconds
+     * @return the search request
+     */
+    private static SearchRequest createForecastRealtimeInittedEverRequest(String forecasterId, long enabledTime, String resultIndex) {
+        BoolQueryBuilder filterQuery = new BoolQueryBuilder();
+        filterQuery.filter(QueryBuilders.termQuery(ForecastCommonName.FORECASTER_ID_KEY, forecasterId));
+        filterQuery.filter(QueryBuilders.rangeQuery(CommonName.EXECUTION_END_TIME_FIELD).gte(enabledTime));
+        ExistsQueryBuilder forecastsExistFilter = QueryBuilders.existsQuery(ForecastResult.VALUE_FIELD);
+        filterQuery.mustNot(forecastsExistFilter);
+        // Historical analysis result also stored in result index, which has non-null task_id.
+        // For realtime forecast results, we should filter task_id == null
+        ExistsQueryBuilder taskIdExistsFilter = QueryBuilders.existsQuery(CommonName.TASK_ID_FIELD);
+        filterQuery.mustNot(taskIdExistsFilter);
+
+        SearchSourceBuilder source = new SearchSourceBuilder().query(filterQuery).size(1);
+
+        SearchRequest request = new SearchRequest(ForecastIndex.RESULT.getIndexName());
+        request.source(source);
+        if (resultIndex != null) {
+            request.indices(resultIndex);
+        }
+        return request;
+    }
+
+    public static void confirmRealtimeInitStatus(
+        Config config,
+        long enabledTime,
+        Client client,
+        AnalysisType analysisType,
+        ActionListener<SearchResponse> listener
+    ) {
+        SearchRequest searchLatestResult = null;
+        switch (analysisType) {
+            case AD:
+                searchLatestResult = createADRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex());
+                break;
+            case FORECAST:
+                searchLatestResult = createForecastRealtimeInittedEverRequest(config.getId(), enabledTime, config.getCustomResultIndex());
+                break;
+            default:
+                throw new IllegalArgumentException("Analysis type is not supported, type: " + analysisType);
+        }
+
+        client.search(searchLatestResult, listener);
+    }
+}
diff --git a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java
index 9d3e827eb..3d60656f3 100644
--- a/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java
+++ b/src/main/java/org/opensearch/timeseries/TimeSeriesAnalyticsPlugin.java
@@ -12,6 +12,8 @@
 package org.opensearch.timeseries;

 import static java.util.Collections.unmodifiableList;
+import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES;

 import java.security.AccessController;
 import java.security.PrivilegedAction;
@@ -34,38 +36,27 @@
 import org.opensearch.SpecialPermission;
 import org.opensearch.action.ActionRequest;
 import org.opensearch.action.ActionResponse;
-import
org.opensearch.ad.AnomalyDetectorJobRunner; +import org.opensearch.ad.ADJobProcessor; import org.opensearch.ad.AnomalyDetectorRunner; import org.opensearch.ad.ExecuteADResultResponseRecorder; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.caching.PriorityCache; -import org.opensearch.ad.cluster.ADClusterEventListener; -import org.opensearch.ad.cluster.ADDataMigrator; -import org.opensearch.ad.cluster.ClusterManagerEventListener; -import org.opensearch.ad.cluster.HashRing; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.caching.ADPriorityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.ratelimit.CheckPointMaintainRequestAdapter; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; import org.opensearch.ad.rest.RestAnomalyDetectorJobAction; import org.opensearch.ad.rest.RestDeleteAnomalyDetectorAction; import org.opensearch.ad.rest.RestDeleteAnomalyResultsAction; @@ -80,17 +71,11 @@ import org.opensearch.ad.rest.RestSearchTopAnomalyResultAction; import org.opensearch.ad.rest.RestStatsAnomalyDetectorAction; import org.opensearch.ad.rest.RestValidateAnomalyDetectorAction; +import org.opensearch.ad.rest.handler.ADIndexJobActionHandler; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.ADNumericSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.settings.LegacyOpenDistroAnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeCountSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; import org.opensearch.ad.task.ADBatchTaskRunner; import org.opensearch.ad.task.ADTaskCacheManager; import 
org.opensearch.ad.task.ADTaskManager; @@ -112,16 +97,16 @@ import org.opensearch.ad.transport.AnomalyResultTransportAction; import org.opensearch.ad.transport.CronAction; import org.opensearch.ad.transport.CronTransportAction; +import org.opensearch.ad.transport.DeleteADModelAction; +import org.opensearch.ad.transport.DeleteADModelTransportAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorAction; import org.opensearch.ad.transport.DeleteAnomalyDetectorTransportAction; import org.opensearch.ad.transport.DeleteAnomalyResultsAction; import org.opensearch.ad.transport.DeleteAnomalyResultsTransportAction; -import org.opensearch.ad.transport.DeleteModelAction; -import org.opensearch.ad.transport.DeleteModelTransportAction; +import org.opensearch.ad.transport.EntityADResultAction; +import org.opensearch.ad.transport.EntityADResultTransportAction; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileTransportAction; -import org.opensearch.ad.transport.EntityResultAction; -import org.opensearch.ad.transport.EntityResultTransportAction; import org.opensearch.ad.transport.ForwardADTaskAction; import org.opensearch.ad.transport.ForwardADTaskTransportAction; import org.opensearch.ad.transport.GetAnomalyDetectorAction; @@ -154,14 +139,8 @@ import org.opensearch.ad.transport.ThresholdResultTransportAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorAction; import org.opensearch.ad.transport.ValidateAnomalyDetectorTransportAction; +import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler; import org.opensearch.ad.transport.handler.ADSearchHandler; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; -import org.opensearch.ad.transport.handler.AnomalyResultBulkIndexHandler; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.IndexUtils; -import org.opensearch.ad.util.SecurityClientUtil; -import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodes; @@ -179,8 +158,51 @@ import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.env.Environment; import org.opensearch.env.NodeEnvironment; +import org.opensearch.forecast.ExecuteForecastResultResponseRecorder; +import org.opensearch.forecast.ForecastJobProcessor; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.caching.ForecastPriorityCache; +import org.opensearch.forecast.indices.ForecastIndex; +import org.opensearch.forecast.indices.ForecastIndexManagement; +import org.opensearch.forecast.ml.ForecastCheckpointDao; +import org.opensearch.forecast.ml.ForecastColdStart; +import org.opensearch.forecast.ml.ForecastModelManager; +import org.opensearch.forecast.model.ForecastResult; import org.opensearch.forecast.model.Forecaster; +import org.opensearch.forecast.ratelimit.ForecastCheckpointMaintainWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointReadWorker; +import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker; +import org.opensearch.forecast.ratelimit.ForecastColdEntityWorker; +import org.opensearch.forecast.ratelimit.ForecastColdStartWorker; +import org.opensearch.forecast.ratelimit.ForecastResultWriteWorker; +import org.opensearch.forecast.rest.RestExecuteForecasterAction; +import 
org.opensearch.forecast.rest.RestForecasterJobAction; +import org.opensearch.forecast.rest.RestGetForecasterAction; +import org.opensearch.forecast.rest.RestIndexForecasterAction; +import org.opensearch.forecast.rest.handler.ForecastIndexJobActionHandler; +import org.opensearch.forecast.settings.ForecastEnabledSetting; +import org.opensearch.forecast.settings.ForecastNumericSetting; import org.opensearch.forecast.settings.ForecastSettings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.forecast.transport.DeleteForecastModelAction; +import org.opensearch.forecast.transport.DeleteForecastModelTransportAction; +import org.opensearch.forecast.transport.EntityForecastResultAction; +import org.opensearch.forecast.transport.EntityForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastResultAction; +import org.opensearch.forecast.transport.ForecastResultBulkAction; +import org.opensearch.forecast.transport.ForecastResultBulkTransportAction; +import org.opensearch.forecast.transport.ForecastResultTransportAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultAction; +import org.opensearch.forecast.transport.ForecastSingleStreamResultTransportAction; +import org.opensearch.forecast.transport.ForecasterJobAction; +import org.opensearch.forecast.transport.ForecasterJobTransportAction; +import org.opensearch.forecast.transport.GetForecasterAction; +import org.opensearch.forecast.transport.GetForecasterTransportAction; +import org.opensearch.forecast.transport.IndexForecasterAction; +import org.opensearch.forecast.transport.IndexForecasterTransportAction; +import org.opensearch.forecast.transport.StopForecasterAction; +import org.opensearch.forecast.transport.StopForecasterTransportAction; +import org.opensearch.forecast.transport.handler.ForecastIndexMemoryPressureAwareResultHandler; import org.opensearch.jobscheduler.spi.JobSchedulerExtension; import org.opensearch.jobscheduler.spi.ScheduledJobParser; import org.opensearch.jobscheduler.spi.ScheduledJobRunner; @@ -196,15 +218,40 @@ import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.function.ThrowingSupplierWrapper; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeCountSupplier; +import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeSupplier; +import 
org.opensearch.timeseries.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler; +import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.IndexUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.watcher.ResourceWatcherService; +import com.amazon.randomcutforest.parkservices.RCFCaster; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; +import com.amazon.randomcutforest.parkservices.state.RCFCasterMapper; +import com.amazon.randomcutforest.parkservices.state.RCFCasterState; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestMapper; import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; import com.amazon.randomcutforest.serialize.json.v1.V1JsonToV3StateConverter; @@ -219,7 +266,7 @@ import io.protostuff.runtime.RuntimeSchema; /** - * Entry point of AD plugin. + * Entry point of time series analytics plugin. */ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, ScriptPlugin, JobSchedulerExtension { @@ -245,21 +292,24 @@ public class TimeSeriesAnalyticsPlugin extends Plugin implements ActionPlugin, S private static Gson gson; private ADIndexManagement anomalyDetectionIndices; + private ForecastIndexManagement forecastIndices; private AnomalyDetectorRunner anomalyDetectorRunner; private Client client; private ClusterService clusterService; private ThreadPool threadPool; - private ADStats adStats; + private Stats timeSeriesStats; private ClientUtil clientUtil; private SecurityClientUtil securityClientUtil; private DiscoveryNodeFilterer nodeFilter; private IndexUtils indexUtils; private ADTaskManager adTaskManager; + private ForecastTaskManager forecastTaskManager; private ADBatchTaskRunner adBatchTaskRunner; // package private for testing GenericObjectPool serializeRCFBufferPool; private NodeStateManager stateManager; private ExecuteADResultResponseRecorder adResultResponseRecorder; + private ExecuteForecastResultResponseRecorder forecastResultResponseRecorder; static { SpecialPermission.check(); @@ -280,14 +330,15 @@ public List getRestHandlers( IndexNameExpressionResolver indexNameExpressionResolver, Supplier nodesInCluster ) { - AnomalyDetectorJobRunner jobRunner = AnomalyDetectorJobRunner.getJobRunnerInstance(); - jobRunner.setClient(client); - jobRunner.setThreadPool(threadPool); - jobRunner.setSettings(settings); - jobRunner.setAnomalyDetectionIndices(anomalyDetectionIndices); - jobRunner.setAdTaskManager(adTaskManager); - jobRunner.setNodeStateManager(stateManager); - jobRunner.setExecuteADResultResponseRecorder(adResultResponseRecorder); + // AD + ADJobProcessor adJobRunner = ADJobProcessor.getInstance(); + adJobRunner.setClient(client); + adJobRunner.setThreadPool(threadPool); + adJobRunner.registerSettings(settings); + adJobRunner.setIndexManagement(anomalyDetectionIndices); + adJobRunner.setTaskManager(adTaskManager); + adJobRunner.setNodeStateManager(stateManager); + adJobRunner.setExecuteResultResponseRecorder(adResultResponseRecorder); RestGetAnomalyDetectorAction restGetAnomalyDetectorAction = new RestGetAnomalyDetectorAction(); RestIndexAnomalyDetectorAction restIndexAnomalyDetectorAction = new RestIndexAnomalyDetectorAction(settings, clusterService); @@ -296,7 +347,7 @@ public List getRestHandlers( RestSearchADTasksAction 
searchADTasksAction = new RestSearchADTasksAction(); RestDeleteAnomalyDetectorAction deleteAnomalyDetectorAction = new RestDeleteAnomalyDetectorAction(); RestExecuteAnomalyDetectorAction executeAnomalyDetectorAction = new RestExecuteAnomalyDetectorAction(settings, clusterService); - RestStatsAnomalyDetectorAction statsAnomalyDetectorAction = new RestStatsAnomalyDetectorAction(adStats, this.nodeFilter); + RestStatsAnomalyDetectorAction statsAnomalyDetectorAction = new RestStatsAnomalyDetectorAction(timeSeriesStats, this.nodeFilter); RestAnomalyDetectorJobAction anomalyDetectorJobAction = new RestAnomalyDetectorJobAction(settings, clusterService); RestSearchAnomalyDetectorInfoAction searchAnomalyDetectorInfoAction = new RestSearchAnomalyDetectorInfoAction(); RestPreviewAnomalyDetectorAction previewAnomalyDetectorAction = new RestPreviewAnomalyDetectorAction(); @@ -304,8 +355,24 @@ public List getRestHandlers( RestSearchTopAnomalyResultAction searchTopAnomalyResultAction = new RestSearchTopAnomalyResultAction(); RestValidateAnomalyDetectorAction validateAnomalyDetectorAction = new RestValidateAnomalyDetectorAction(settings, clusterService); + // Forecast + RestIndexForecasterAction restIndexForecasterAction = new RestIndexForecasterAction(settings, clusterService); + RestExecuteForecasterAction restExecuteForecasterAction = new RestExecuteForecasterAction(); + RestForecasterJobAction restForecasterJobAction = new RestForecasterJobAction(); + RestGetForecasterAction restGetForecasterAction = new RestGetForecasterAction(); + + ForecastJobProcessor forecastJobRunner = ForecastJobProcessor.getInstance(); + forecastJobRunner.setClient(client); + forecastJobRunner.setThreadPool(threadPool); + forecastJobRunner.registerSettings(settings); + forecastJobRunner.setIndexManagement(forecastIndices); + forecastJobRunner.setTaskManager(forecastTaskManager); + forecastJobRunner.setNodeStateManager(stateManager); + forecastJobRunner.setExecuteResultResponseRecorder(forecastResultResponseRecorder); + return ImmutableList .of( + // AD restGetAnomalyDetectorAction, restIndexAnomalyDetectorAction, searchAnomalyDetectorAction, @@ -319,7 +386,12 @@ public List getRestHandlers( previewAnomalyDetectorAction, deleteAnomalyResultsAction, searchTopAnomalyResultAction, - validateAnomalyDetectorAction + validateAnomalyDetectorAction, + // Forecast + restIndexForecasterAction, + restExecuteForecasterAction, + restForecasterJobAction, + restGetForecasterAction ); } @@ -342,41 +414,64 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - ADEnabledSetting.getInstance().init(clusterService); - ADNumericSetting.getInstance().init(clusterService); + // ===================== + // Common components + // ===================== this.client = client; this.threadPool = threadPool; Settings settings = environment.settings(); - Throttler throttler = new Throttler(getClock()); - this.clientUtil = new ClientUtil(settings, client, throttler, threadPool); - this.indexUtils = new IndexUtils(client, clientUtil, clusterService, indexNameExpressionResolver); + this.clientUtil = new ClientUtil(client); + this.indexUtils = new IndexUtils(clusterService, indexNameExpressionResolver); this.nodeFilter = new DiscoveryNodeFilterer(clusterService); - // convert from checked IOException to unchecked RuntimeException - this.anomalyDetectionIndices = ThrowingSupplierWrapper - .throwingSupplierWrapper( - () -> new ADIndexManagement( - client, - clusterService, - 
threadPool, - settings, - nodeFilter, - TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES - ) - ) - .get(); this.clusterService = clusterService; - Imputer imputer = new LinearUniformImputer(true); + + JvmService jvmService = new JvmService(environment.settings()); + RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + rcfMapper.setSaveExecutorContextEnabled(true); + rcfMapper.setSaveTreeStateEnabled(true); + rcfMapper.setPartialTreeStateEnabled(true); + V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); + + CircuitBreakerService circuitBreakerService = new CircuitBreakerService(jvmService).init(); + + long heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + + serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { + @Override + public GenericObjectPool run() { + return new GenericObjectPool<>(new BasePooledObjectFactory() { + @Override + public LinkedBuffer create() throws Exception { + return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); + } + + @Override + public PooledObject wrap(LinkedBuffer obj) { + return new DefaultPooledObject<>(obj); + } + }); + } + }); + serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMinIdle(0); + serializeRCFBufferPool.setBlockWhenExhausted(false); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); + stateManager = new NodeStateManager( client, xContentRegistry, settings, clientUtil, getClock(), - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - clusterService + TimeSeriesSettings.HOURLY_MAINTENANCE, + clusterService, + TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + TimeSeriesSettings.BACKOFF_MINUTES ); securityClientUtil = new SecurityClientUtil(stateManager, settings); + SearchFeatureDao searchFeatureDao = new SearchFeatureDao( client, xContentRegistry, @@ -384,75 +479,54 @@ public Collection createComponents( securityClientUtil, settings, clusterService, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE - ); - - JvmService jvmService = new JvmService(environment.settings()); - RandomCutForestMapper mapper = new RandomCutForestMapper(); - mapper.setSaveExecutorContextEnabled(true); - mapper.setSaveTreeStateEnabled(true); - mapper.setPartialTreeStateEnabled(true); - V1JsonToV3StateConverter converter = new V1JsonToV3StateConverter(); - - double modelMaxSizePercent = AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings); - - ADCircuitBreakerService adCircuitBreakerService = new ADCircuitBreakerService(jvmService).init(); - - MemoryTracker memoryTracker = new MemoryTracker( - jvmService, - modelMaxSizePercent, - AnomalyDetectorSettings.DESIRED_MODEL_SIZE_PERCENTAGE, - clusterService, - adCircuitBreakerService + TimeSeriesSettings.NUM_SAMPLES_PER_TREE ); FeatureManager featureManager = new FeatureManager( searchFeatureDao, imputer, getClock(), - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, AD_THREAD_POOL_NAME ); - long 
heapSizeBytes = JvmInfo.jvmInfo().getMem().getHeapMax().getBytes(); + Random random = new Random(42); - serializeRCFBufferPool = AccessController.doPrivileged(new PrivilegedAction>() { - @Override - public GenericObjectPool run() { - return new GenericObjectPool<>(new BasePooledObjectFactory() { - @Override - public LinkedBuffer create() throws Exception { - return LinkedBuffer.allocate(AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES); - } + // ===================== + // AD components + // ===================== + ADEnabledSetting.getInstance().init(clusterService); + ADNumericSetting.getInstance().init(clusterService); + // convert from checked IOException to unchecked RuntimeException + this.anomalyDetectionIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ADIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + TimeSeriesSettings.MAX_UPDATE_RETRY_TIMES + ) + ) + .get(); - @Override - public PooledObject wrap(LinkedBuffer obj) { - return new DefaultPooledObject<>(obj); - } - }); - } - }); - serializeRCFBufferPool.setMaxTotal(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMaxIdle(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMinIdle(0); - serializeRCFBufferPool.setBlockWhenExhausted(false); - serializeRCFBufferPool.setTimeBetweenEvictionRuns(AnomalyDetectorSettings.HOURLY_MAINTENANCE); + double adModelMaxSizePercent = AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings); + + MemoryTracker adMemoryTracker = new MemoryTracker(jvmService, adModelMaxSizePercent, clusterService, circuitBreakerService); - CheckpointDao checkpoint = new CheckpointDao( + ADCheckpointDao adCheckpoint = new ADCheckpointDao( client, clientUtil, - ADCommonName.CHECKPOINT_INDEX_NAME, gson, - mapper, + rcfMapper, converter, new ThresholdedRandomCutForestMapper(), AccessController @@ -462,273 +536,659 @@ public PooledObject wrap(LinkedBuffer obj) { ), HybridThresholdingModel.class, anomalyDetectionIndices, - AnomalyDetectorSettings.MAX_CHECKPOINT_BYTES, + TimeSeriesSettings.MAX_CHECKPOINT_BYTES, serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, - 1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, + 1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + getClock() ); - Random random = new Random(42); - - CacheProvider cacheProvider = new CacheProvider(); + ADCacheProvider adCacheProvider = new ADCacheProvider(); - CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( - cacheProvider, - checkpoint, - ADCommonName.CHECKPOINT_INDEX_NAME, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - getClock(), - clusterService, - settings - ); + CheckPointMaintainRequestAdapter adAdapter = + new CheckPointMaintainRequestAdapter<>( + adCheckpoint, + ADCommonName.CHECKPOINT_INDEX_NAME, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + adCacheProvider + ); - CheckpointWriteWorker checkpointWriteQueue = new CheckpointWriteWorker( + ADCheckpointWriteWorker adCheckpointWriteQueue = new ADCheckpointWriteWorker( heapSizeBytes, - AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + 
circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, - checkpoint, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adCheckpoint, ADCommonName.CHECKPOINT_INDEX_NAME, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - AnomalyDetectorSettings.HOURLY_MAINTENANCE + TimeSeriesSettings.HOURLY_MAINTENANCE ); - CheckpointMaintainWorker checkpointMaintainQueue = new CheckpointMaintainWorker( + ADCheckpointMaintainWorker adCheckpointMaintainQueue = new ADCheckpointMaintainWorker( heapSizeBytes, - AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointWriteQueue, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + adCheckpointWriteQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - adapter + adAdapter::convert ); - EntityCache cache = new PriorityCache( - checkpoint, - AnomalyDetectorSettings.DEDICATED_CACHE_SIZE.get(settings), - AnomalyDetectorSettings.CHECKPOINT_TTL, + ADPriorityCache adPriorityCache = new ADPriorityCache( + adCheckpoint, + AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE.get(settings), + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, - memoryTracker, - AnomalyDetectorSettings.NUM_TREES, + adMemoryTracker, + TimeSeriesSettings.NUM_TREES, getClock(), clusterService, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, - checkpointWriteQueue, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointMaintainQueue, + AD_THREAD_POOL_NAME, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, settings, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + adCheckpointWriteQueue, + adCheckpointMaintainQueue ); - cacheProvider.set(cache); + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + adCacheProvider.set(adPriorityCache); - EntityColdStarter entityColdStarter = new EntityColdStarter( + ADEntityColdStart adEntityColdStarter = new ADEntityColdStart( getClock(), threadPool, stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + 
TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - imputer, searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, - settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.HOURLY_MAINTENANCE, + adCheckpointWriteQueue, + TimeSeriesSettings.MAX_COLD_START_ROUNDS, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()) ); - EntityColdStartWorker coldstartQueue = new EntityColdStartWorker( + ADColdStartWorker adColdstartQueue = new ADColdStartWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, - entityColdStarter, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adEntityColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - cacheProvider + adPriorityCache ); - ModelManager modelManager = new ModelManager( - checkpoint, + ADModelManager adModelManager = new ADModelManager( + adCheckpoint, getClock(), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - entityColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + adEntityColdStarter, featureManager, - memoryTracker, + adMemoryTracker, settings, clusterService ); - MultiEntityResultHandler multiEntityResultHandler = new MultiEntityResultHandler( + ADIndexMemoryPressureAwareResultHandler adIndexMemoryPressureAwareResultHandler = new ADIndexMemoryPressureAwareResultHandler( client, + anomalyDetectionIndices + ); + + ADResultWriteWorker adResultWriteQueue = new ADResultWriteWorker( + heapSizeBytes, + TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adIndexMemoryPressureAwareResultHandler, + xContentRegistry, + stateManager, + 
TimeSeriesSettings.HOURLY_MAINTENANCE + ); + + ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + adModelManager, + adCheckpoint, + adColdstartQueue, + adResultWriteQueue, + stateManager, anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService + adCacheProvider, + TimeSeriesSettings.HOURLY_MAINTENANCE, + adCheckpointWriteQueue, + timeSeriesStats ); - ResultWriteWorker resultWriteQueue = new ResultWriteWorker( + ADColdEntityWorker adColdEntityQueue = new ADColdEntityWorker( heapSizeBytes, - AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, - AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, - multiEntityResultHandler, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + adCheckpointReadQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager + ); + + ADDataMigrator adDataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); + + anomalyDetectorRunner = new AnomalyDetectorRunner(adModelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + + ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, adMemoryTracker); + + ResultBulkIndexingHandler anomalyResultBulkIndexHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); + + ResultBulkIndexingHandler anomalyResultHandler = new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ANOMALY_RESULT_INDEX_ALIAS, + anomalyDetectionIndices, + this.clientUtil, + this.indexUtils, + clusterService, + AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF + ); + + adResultResponseRecorder = new ExecuteADResultResponseRecorder( + anomalyDetectionIndices, + anomalyResultHandler, + adTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + adTaskCacheManager, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + + ADIndexJobActionHandler adIndexJobActionHandler = new ADIndexJobActionHandler( + client, + anomalyDetectionIndices, xContentRegistry, + adTaskManager, + adResultResponseRecorder, stateManager, - 
AnomalyDetectorSettings.HOURLY_MAINTENANCE + settings ); - Map> stats = ImmutableMap - .>builder() - .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put( - StatNames.MODEL_INFORMATION.getName(), - new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) - ) - .put( - StatNames.ANOMALY_DETECTORS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) - ) - .put( - StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) - ) - .put( - StatNames.MODELS_CHECKPOINT_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) - ) - .put( - StatNames.ANOMALY_DETECTION_JOB_INDEX_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) - ) - .put( - StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), - new ADStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + // ===================== + // forecast components + // ===================== + ForecastEnabledSetting.getInstance().init(clusterService); + ForecastNumericSetting.getInstance().init(clusterService); + + forecastIndices = ThrowingSupplierWrapper + .throwingSupplierWrapper( + () -> new ForecastIndexManagement( + client, + clusterService, + threadPool, + settings, + nodeFilter, + ForecastSettings.FORECAST_MAX_UPDATE_RETRY_TIMES + ) ) - .put(StatNames.DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.MULTI_ENTITY_DETECTOR_COUNT.getName(), new ADStat<>(true, new SettableSupplier())) - .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .put(StatNames.MODEL_COUNT.getName(), new ADStat<>(false, new ModelsOnNodeCountSupplier(modelManager, cacheProvider))) - .put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())) - .build(); + .get(); + + double forecastModelMaxSizePercent = ForecastSettings.FORECAST_MODEL_MAX_SIZE_PERCENTAGE.get(settings); - adStats = new ADStats(stats); + MemoryTracker forecastMemoryTracker = new MemoryTracker( + jvmService, + forecastModelMaxSizePercent, + clusterService, + circuitBreakerService + ); - CheckpointReadWorker checkpointReadQueue = new CheckpointReadWorker( + ForecastCheckpointDao forecastCheckpoint = new ForecastCheckpointDao( + client, + clientUtil, + gson, + TimeSeriesSettings.MAX_CHECKPOINT_BYTES, + serializeRCFBufferPool, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, + forecastIndices, + new RCFCasterMapper(), + AccessController.doPrivileged((PrivilegedAction>) () -> RuntimeSchema.getSchema(RCFCasterState.class)), + getClock() + ); + + 
ForecastCacheProvider forecastCacheProvider = new ForecastCacheProvider(); + + CheckPointMaintainRequestAdapter forecastAdapter = + new CheckPointMaintainRequestAdapter( + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + getClock(), + clusterService, + settings, + forecastCacheProvider + ); + + ForecastCheckpointWriteWorker forecastCheckpointWriteQueue = new ForecastCheckpointWriteWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, - modelManager, - checkpoint, - coldstartQueue, - resultWriteQueue, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastCheckpoint, + ForecastIndex.CHECKPOINT.getIndexName(), + TimeSeriesSettings.HOURLY_MAINTENANCE, stateManager, - anomalyDetectionIndices, - cacheProvider, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, - checkpointWriteQueue, - adStats + TimeSeriesSettings.HOURLY_MAINTENANCE ); - ColdEntityWorker coldEntityQueue = new ColdEntityWorker( + ForecastCheckpointMaintainWorker forecastCheckpointMaintainQueue = new ForecastCheckpointMaintainWorker( heapSizeBytes, - AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + TimeSeriesSettings.CHECKPOINT_MAINTAIN_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, random, - adCircuitBreakerService, + circuitBreakerService, threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, getClock(), - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - checkpointReadQueue, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + forecastCheckpointWriteQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastAdapter::convert + ); + + ForecastPriorityCache forecastPriorityCache = new ForecastPriorityCache( + forecastCheckpoint, + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE.get(settings), + ForecastSettings.FORECAST_CHECKPOINT_TTL, + AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, + forecastMemoryTracker, + TimeSeriesSettings.NUM_TREES, + getClock(), + clusterService, + TimeSeriesSettings.HOURLY_MAINTENANCE, + threadPool, + FORECAST_THREAD_POOL_NAME, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + settings, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + forecastCheckpointWriteQueue, + forecastCheckpointMaintainQueue + ); + + // cache provider allows us to break circular dependency among PriorityCache, CacheBuffer, + // 
CheckPointMaintainRequestAdapter, and CheckpointMaintainWorker + forecastCacheProvider.set(forecastPriorityCache); + + ForecastColdStart forecastColdStarter = new ForecastColdStart( + getClock(), + threadPool, + stateManager, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + searchFeatureDao, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, + featureManager, + TimeSeriesSettings.HOURLY_MAINTENANCE, + forecastCheckpointWriteQueue, + (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()), + -1, // no hard coded random seed + TimeSeriesSettings.TIME_DECAY, + -1, // interpolation is disabled so we don't need to specify the number of sampled points + TimeSeriesSettings.MAX_COLD_START_ROUNDS + ); + + ForecastColdStartWorker forecastColdstartQueue = new ForecastColdStartWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastColdStarter, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager, + forecastPriorityCache + ); + + ForecastModelManager forecastModelManager = new ForecastModelManager( + forecastCheckpoint, + getClock(), + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + forecastColdStarter, + forecastMemoryTracker, + featureManager + ); + + ForecastIndexMemoryPressureAwareResultHandler forecastIndexMemoryPressureAwareResultHandler = + new ForecastIndexMemoryPressureAwareResultHandler(client, forecastIndices); + + ForecastResultWriteWorker forecastResultWriteQueue = new ForecastResultWriteWorker( + heapSizeBytes, + TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastIndexMemoryPressureAwareResultHandler, + xContentRegistry, + stateManager, + TimeSeriesSettings.HOURLY_MAINTENANCE + ); + + ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker( + heapSizeBytes, + TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, + forecastModelManager, + forecastCheckpoint, + forecastColdstartQueue, + forecastResultWriteQueue, + stateManager, + forecastIndices, + forecastCacheProvider, + TimeSeriesSettings.HOURLY_MAINTENANCE, + forecastCheckpointWriteQueue, + timeSeriesStats + ); + + ForecastColdEntityWorker forecastColdEntityQueue = new ForecastColdEntityWorker( + heapSizeBytes, + 
TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + clusterService, + random, + circuitBreakerService, + threadPool, + settings, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, + getClock(), + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + forecastCheckpointReadQueue, + TimeSeriesSettings.HOURLY_MAINTENANCE, + stateManager + ); + + TaskCacheManager forecastTaskCacheManager = new TaskCacheManager(settings, clusterService); + + forecastTaskManager = new ForecastTaskManager( + forecastTaskCacheManager, + client, + xContentRegistry, + forecastIndices, + clusterService, + settings, + threadPool, stateManager ); - ADDataMigrator dataMigrator = new ADDataMigrator(client, clusterService, xContentRegistry, anomalyDetectionIndices); - HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, dataMigrator, modelManager); + // TODO: do we need it in forecast backtesting? + // ResultBulkIndexingHandler forecastResultBulkIndexHandler = new + // ResultBulkIndexingHandler<>( + // client, + // settings, + // threadPool, + // ForecastCommonName.FORECAST_RESULT_INDEX_ALIAS, + // forecastIndices, + // this.clientUtil, + // this.indexUtils, + // clusterService, + // ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + // ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF + // ); - anomalyDetectorRunner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS); + ResultBulkIndexingHandler forecastResultHandler = + new ResultBulkIndexingHandler<>( + client, + settings, + threadPool, + ForecastIndex.RESULT.getIndexName(), + forecastIndices, + this.clientUtil, + this.indexUtils, + clusterService, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF + ); - ADTaskCacheManager adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); + forecastResultResponseRecorder = new ExecuteForecastResultResponseRecorder( + forecastIndices, + forecastResultHandler, + forecastTaskManager, + nodeFilter, + threadPool, + client, + stateManager, + forecastTaskCacheManager, + TimeSeriesSettings.NUM_MIN_SAMPLES + ); + + ForecastIndexJobActionHandler forecastIndexJobActionHandler = new ForecastIndexJobActionHandler( + client, + forecastIndices, + xContentRegistry, + forecastTaskManager, + forecastResultResponseRecorder, + stateManager, + settings + ); + + // ===================== + // common components, need AD/forecasting components to initialize + // ===================== + Map> stats = ImmutableMap + .>builder() + // ad stats + .put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.ANOMALY_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.ANOMALY_RESULT_INDEX_ALIAS)) + ) + .put( + StatNames.AD_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.CHECKPOINT_INDEX_NAME)) + ) + .put( + StatNames.ANOMALY_DETECTION_STATE_STATUS.getName(), 
+ new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ADCommonName.DETECTION_STATE_INDEX)) + ) + .put(StatNames.DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_DETECTOR_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.AD_EXECUTING_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_CANCELED_BATCH_TASK_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_BATCH_TASK_FAILURE_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + // forecast stats + .put(StatNames.FORECAST_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put(StatNames.FORECAST_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + .put( + StatNames.FORECAST_RESULTS_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.RESULT.getIndexName())) + ) + .put( + StatNames.FORECAST_MODELS_CHECKPOINT_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.CHECKPOINT.getIndexName())) + ) + .put( + StatNames.FORECAST_STATE_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, ForecastIndex.STATE.getIndexName())) + ) + .put(StatNames.FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.SINGLE_STREAM_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.HC_FORECASTER_COUNT.getName(), new TimeSeriesStat<>(true, new SettableSupplier())) + .put(StatNames.FORECAST_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())) + // combined stats + .put( + StatNames.MODEL_INFORMATION.getName(), + new TimeSeriesStat<>( + false, + new ModelsOnNodeSupplier(adModelManager, adCacheProvider, forecastCacheProvider, settings, clusterService) + ) + ) + .put( + StatNames.CONFIG_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.CONFIG_INDEX)) + ) + .put( + StatNames.JOB_INDEX_STATUS.getName(), + new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, CommonName.JOB_INDEX)) + ) + .put( + StatNames.MODEL_COUNT.getName(), + new TimeSeriesStat<>(false, new ModelsOnNodeCountSupplier(adModelManager, adCacheProvider, forecastCacheProvider)) + ) + .build(); + + timeSeriesStats = new Stats(stats); + HashRing hashRing = new HashRing(nodeFilter, getClock(), settings, client, clusterService, adDataMigrator, adModelManager); adTaskManager = new ADTaskManager( settings, clusterService, @@ -738,64 +1198,33 @@ public PooledObject wrap(LinkedBuffer obj) { nodeFilter, hashRing, adTaskCacheManager, - threadPool - ); - AnomalyResultBulkIndexHandler anomalyResultBulkIndexHandler = new AnomalyResultBulkIndexHandler( - client, - settings, threadPool, - this.clientUtil, - 
this.indexUtils, - clusterService, - anomalyDetectionIndices + stateManager ); + adBatchTaskRunner = new ADBatchTaskRunner( settings, threadPool, clusterService, client, securityClientUtil, - adCircuitBreakerService, + circuitBreakerService, featureManager, adTaskManager, anomalyDetectionIndices, - adStats, + timeSeriesStats, anomalyResultBulkIndexHandler, adTaskCacheManager, searchFeatureDao, hashRing, - modelManager - ); - - ADSearchHandler adSearchHandler = new ADSearchHandler(settings, clusterService, client); - - AnomalyIndexHandler anomalyResultHandler = new AnomalyIndexHandler( - client, - settings, - threadPool, - ADCommonName.ANOMALY_RESULT_INDEX_ALIAS, - anomalyDetectionIndices, - this.clientUtil, - this.indexUtils, - clusterService - ); - - adResultResponseRecorder = new ExecuteADResultResponseRecorder( - anomalyDetectionIndices, - anomalyResultHandler, - adTaskManager, - nodeFilter, - threadPool, - client, - stateManager, - adTaskCacheManager, - AnomalyDetectorSettings.NUM_MIN_SAMPLES + adModelManager ); // return objects used by Guice to inject dependencies for e.g., // transport action handler constructors return ImmutableList .of( + // AD components anomalyDetectionIndices, anomalyDetectorRunner, searchFeatureDao, @@ -804,11 +1233,11 @@ public PooledObject wrap(LinkedBuffer obj) { jvmService, hashRing, featureManager, - modelManager, + adModelManager, stateManager, - new ADClusterEventListener(clusterService, hashRing), - adCircuitBreakerService, - adStats, + new ClusterEventListener(clusterService, hashRing), + circuitBreakerService, + timeSeriesStats, new ClusterManagerEventListener( clusterService, threadPool, @@ -816,24 +1245,40 @@ public PooledObject wrap(LinkedBuffer obj) { getClock(), clientUtil, nodeFilter, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, settings ), nodeFilter, - multiEntityResultHandler, - checkpoint, - cacheProvider, + adIndexMemoryPressureAwareResultHandler, + adCheckpoint, + adCacheProvider, adTaskManager, adBatchTaskRunner, adSearchHandler, - coldstartQueue, - resultWriteQueue, - checkpointReadQueue, - checkpointWriteQueue, - coldEntityQueue, - entityColdStarter, + adColdstartQueue, + adResultWriteQueue, + adCheckpointReadQueue, + adCheckpointWriteQueue, + adColdEntityQueue, + adEntityColdStarter, adTaskCacheManager, - adResultResponseRecorder + adResultResponseRecorder, + adIndexJobActionHandler, + // forecast components + forecastIndices, + forecastModelManager, + forecastIndexMemoryPressureAwareResultHandler, + forecastCheckpoint, + forecastCacheProvider, + forecastColdstartQueue, + forecastResultWriteQueue, + forecastCheckpointReadQueue, + forecastCheckpointWriteQueue, + forecastColdEntityQueue, + forecastColdStarter, + forecastTaskManager, + forecastIndexJobActionHandler, + forecastTaskCacheManager ); } @@ -865,14 +1310,33 @@ public List> getExecutorBuilders(Settings settings) { Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 8), TimeValue.timeValueMinutes(10), AD_THREAD_POOL_PREFIX + AD_BATCH_TASK_THREAD_POOL_NAME + ), + new ScalingExecutorBuilder( + FORECAST_THREAD_POOL_NAME, + 1, + // HCAD can be heavy after supporting 1 million entities. + // Limit to use at most half of the processors. 
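+ // Worked example (illustrative): with OpenSearchExecutors.allocatedProcessors(settings) == 16,
+ // this forecast pool is capped at max(1, 16 / 2) = 8 threads and the forecast batch task
+ // pool below at max(1, 16 / 8) = 2 threads; a single-core node still gets one thread each.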
+ Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 2), + TimeValue.timeValueMinutes(10), + FORECAST_THREAD_POOL_PREFIX + FORECAST_THREAD_POOL_NAME + ), + new ScalingExecutorBuilder( + FORECAST_BATCH_TASK_THREAD_POOL_NAME, + 1, + Math.max(1, OpenSearchExecutors.allocatedProcessors(settings) / 8), + TimeValue.timeValueMinutes(10), + FORECAST_THREAD_POOL_PREFIX + FORECAST_BATCH_TASK_THREAD_POOL_NAME ) ); } @Override public List> getSettings() { - List> enabledSetting = ADEnabledSetting.getInstance().getSettings(); - List> numericSetting = ADNumericSetting.getInstance().getSettings(); + List> adEnabledSetting = ADEnabledSetting.getInstance().getSettings(); + List> adNumericSetting = ADNumericSetting.getInstance().getSettings(); + + List> forecastEnabledSetting = ForecastEnabledSetting.getInstance().getSettings(); + List> forecastNumericSetting = ForecastNumericSetting.getInstance().getSettings(); List> systemSetting = ImmutableList .of( @@ -881,7 +1345,7 @@ public List> getSettings() { // ====================================== // HCAD cache LegacyOpenDistroAnomalyDetectorSettings.MAX_CACHE_MISS_HANDLING_PER_SECOND, - AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, + AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE, // Detector config LegacyOpenDistroAnomalyDetectorSettings.DETECTION_INTERVAL, LegacyOpenDistroAnomalyDetectorSettings.DETECTION_WINDOW_DELAY, @@ -896,10 +1360,10 @@ public List> getSettings() { LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES, LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_INITIAL_DELAY, LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF, - AnomalyDetectorSettings.REQUEST_TIMEOUT, - AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, - AnomalyDetectorSettings.COOLDOWN_MINUTES, - AnomalyDetectorSettings.BACKOFF_MINUTES, + AnomalyDetectorSettings.AD_REQUEST_TIMEOUT, + AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE, + AnomalyDetectorSettings.AD_COOLDOWN_MINUTES, + AnomalyDetectorSettings.AD_BACKOFF_MINUTES, AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY, AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF, // result index rollover @@ -915,15 +1379,15 @@ public List> getSettings() { LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, LegacyOpenDistroAnomalyDetectorSettings.INDEX_PRESSURE_SOFT_LIMIT, LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS, - AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, - AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS, + AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS, AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT, AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT, AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS, // Security LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, - AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES, + AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES, // Historical LegacyOpenDistroAnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE, LegacyOpenDistroAnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS, @@ -939,35 +1403,44 @@ public List> getSettings() { // rate limiting AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, - AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY, + 
AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY, AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE, - AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, - AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, - AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_CHECKPOINT_SAVING_FREQ, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, // query limit LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, - AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, + AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW, - AnomalyDetectorSettings.PAGE_SIZE, + AnomalyDetectorSettings.AD_PAGE_SIZE, // clean resource AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, // stats/profile API - AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE, + AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE, // ====================================== // Forecast settings // ====================================== + // HC forecasting cache + ForecastSettings.FORECAST_DEDICATED_CACHE_SIZE, + // config parameters + ForecastSettings.FORECAST_INTERVAL, + ForecastSettings.FORECAST_WINDOW_DELAY, + // Fault tolerance + ForecastSettings.FORECAST_BACKOFF_MINUTES, + ForecastSettings.FORECAST_BACKOFF_INITIAL_DELAY, + ForecastSettings.FORECAST_MAX_RETRY_FOR_BACKOFF, // result index rollover ForecastSettings.FORECAST_RESULT_HISTORY_MAX_DOCS_PER_SHARD, ForecastSettings.FORECAST_RESULT_HISTORY_RETENTION_PERIOD, @@ -979,11 +1452,58 @@ public List> getSettings() { // ForecastSettings.FORECAST_MAX_HC_FORECASTERS, ForecastSettings.FORECAST_INDEX_PRESSURE_SOFT_LIMIT, ForecastSettings.FORECAST_INDEX_PRESSURE_HARD_LIMIT, - ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS + ForecastSettings.FORECAST_MAX_PRIMARY_SHARDS, + // restful apis + ForecastSettings.FORECAST_REQUEST_TIMEOUT, + // resource constraint + ForecastSettings.MAX_SINGLE_STREAM_FORECASTERS, + ForecastSettings.MAX_HC_FORECASTERS, + // Security + ForecastSettings.FORECAST_FILTER_BY_BACKEND_ROLES, + // Historical + ForecastSettings.MAX_OLD_TASK_DOCS_PER_FORECASTER, + // rate limiting + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + 
ForecastSettings.FORECAST_COLD_START_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_CONCURRENCY, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_BATCH_SIZE, + ForecastSettings.FORECAST_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_COLD_START_QUEUE_MAX_HEAP_PERCENT, + ForecastSettings.FORECAST_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, + ForecastSettings.FORECAST_CHECKPOINT_SAVING_FREQ, + ForecastSettings.FORECAST_CHECKPOINT_TTL, + // query limit + ForecastSettings.FORECAST_MAX_ENTITIES_PER_INTERVAL, + ForecastSettings.FORECAST_PAGE_SIZE, + // stats/profile API + ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE, + // ====================================== + // Common settings + // ====================================== + // Fault tolerance + TimeSeriesSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE, + TimeSeriesSettings.BACKOFF_MINUTES, + TimeSeriesSettings.COOLDOWN_MINUTES, + // tasks + TimeSeriesSettings.MAX_CACHED_DELETED_TASKS ); return unmodifiableList( Stream - .of(enabledSetting.stream(), systemSetting.stream(), numericSetting.stream()) + .of( + adEnabledSetting.stream(), + forecastEnabledSetting.stream(), + systemSetting.stream(), + adNumericSetting.stream(), + forecastNumericSetting.stream() + ) .reduce(Stream::concat) .orElseGet(Stream::empty) .collect(Collectors.toList()) @@ -997,7 +1517,7 @@ public List getNamedXContent() { AnomalyDetector.XCONTENT_REGISTRY, AnomalyResult.XCONTENT_REGISTRY, DetectorInternalState.XCONTENT_REGISTRY, - AnomalyDetectorJob.XCONTENT_REGISTRY, + Job.XCONTENT_REGISTRY, Forecaster.XCONTENT_REGISTRY ); } @@ -1009,7 +1529,8 @@ public List getNamedXContent() { public List> getActions() { return Arrays .asList( - new ActionHandler<>(DeleteModelAction.INSTANCE, DeleteModelTransportAction.class), + // AD + new ActionHandler<>(DeleteADModelAction.INSTANCE, DeleteADModelTransportAction.class), new ActionHandler<>(StopDetectorAction.INSTANCE, StopDetectorTransportAction.class), new ActionHandler<>(RCFResultAction.INSTANCE, RCFResultTransportAction.class), new ActionHandler<>(ThresholdResultAction.INSTANCE, ThresholdResultTransportAction.class), @@ -1027,7 +1548,7 @@ public List getNamedXContent() { new ActionHandler<>(IndexAnomalyDetectorAction.INSTANCE, IndexAnomalyDetectorTransportAction.class), new ActionHandler<>(AnomalyDetectorJobAction.INSTANCE, AnomalyDetectorJobTransportAction.class), new ActionHandler<>(ADResultBulkAction.INSTANCE, ADResultBulkTransportAction.class), - new ActionHandler<>(EntityResultAction.INSTANCE, EntityResultTransportAction.class), + new ActionHandler<>(EntityADResultAction.INSTANCE, EntityADResultTransportAction.class), new ActionHandler<>(EntityProfileAction.INSTANCE, EntityProfileTransportAction.class), new ActionHandler<>(SearchAnomalyDetectorInfoAction.INSTANCE, SearchAnomalyDetectorInfoTransportAction.class), new ActionHandler<>(PreviewAnomalyDetectorAction.INSTANCE, PreviewAnomalyDetectorTransportAction.class), @@ -1038,7 +1559,17 @@ public List getNamedXContent() { new ActionHandler<>(ForwardADTaskAction.INSTANCE, 
ForwardADTaskTransportAction.class), new ActionHandler<>(DeleteAnomalyResultsAction.INSTANCE, DeleteAnomalyResultsTransportAction.class), new ActionHandler<>(SearchTopAnomalyResultAction.INSTANCE, SearchTopAnomalyResultTransportAction.class), - new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class) + new ActionHandler<>(ValidateAnomalyDetectorAction.INSTANCE, ValidateAnomalyDetectorTransportAction.class), + // forecast + new ActionHandler<>(IndexForecasterAction.INSTANCE, IndexForecasterTransportAction.class), + new ActionHandler<>(ForecastResultAction.INSTANCE, ForecastResultTransportAction.class), + new ActionHandler<>(EntityForecastResultAction.INSTANCE, EntityForecastResultTransportAction.class), + new ActionHandler<>(ForecastResultBulkAction.INSTANCE, ForecastResultBulkTransportAction.class), + new ActionHandler<>(ForecastSingleStreamResultAction.INSTANCE, ForecastSingleStreamResultTransportAction.class), + new ActionHandler<>(ForecasterJobAction.INSTANCE, ForecasterJobTransportAction.class), + new ActionHandler<>(StopForecasterAction.INSTANCE, StopForecasterTransportAction.class), + new ActionHandler<>(DeleteForecastModelAction.INSTANCE, DeleteForecastModelTransportAction.class), + new ActionHandler<>(GetForecasterAction.INSTANCE, GetForecasterTransportAction.class) ); } @@ -1054,14 +1585,14 @@ public String getJobIndex() { @Override public ScheduledJobRunner getJobRunner() { - return AnomalyDetectorJobRunner.getJobRunnerInstance(); + return JobRunner.getJobRunnerInstance(); } @Override public ScheduledJobParser getJobParser() { return (parser, id, jobDocVersion) -> { XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - return AnomalyDetectorJob.parse(parser); + return Job.parse(parser); }; } diff --git a/src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreaker.java similarity index 91% rename from src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java rename to src/main/java/org/opensearch/timeseries/breaker/CircuitBreaker.java index 2825d2f98..5258ac64e 100644 --- a/src/main/java/org/opensearch/ad/breaker/CircuitBreaker.java +++ b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreaker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.breaker; +package org.opensearch.timeseries.breaker; /** * An interface for circuit breaker. diff --git a/src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java similarity index 87% rename from src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java rename to src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java index 9c9ab5b34..018834fe6 100644 --- a/src/main/java/org/opensearch/ad/breaker/ADCircuitBreakerService.java +++ b/src/main/java/org/opensearch/timeseries/breaker/CircuitBreakerService.java @@ -9,13 +9,14 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.breaker; +package org.opensearch.timeseries.breaker; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ad.breaker.BreakerName; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.monitor.jvm.JvmService; @@ -24,19 +25,19 @@ * * This service registers internal system breakers and provide API for users to register their own breakers. */ -public class ADCircuitBreakerService { +public class CircuitBreakerService { private final ConcurrentMap<String, CircuitBreaker> breakers = new ConcurrentHashMap<>(); private final JvmService jvmService; - private static final Logger logger = LogManager.getLogger(ADCircuitBreakerService.class); + private static final Logger logger = LogManager.getLogger(CircuitBreakerService.class); /** * Constructor. * * @param jvmService jvm info */ - public ADCircuitBreakerService(JvmService jvmService) { + public CircuitBreakerService(JvmService jvmService) { this.jvmService = jvmService; } @@ -67,7 +68,7 @@ public CircuitBreaker getBreaker(String name) { * * - * @return ADCircuitBreakerService + * @return CircuitBreakerService */ - public ADCircuitBreakerService init() { + public CircuitBreakerService init() { // Register memory circuit breaker registerBreaker(BreakerName.MEM.getName(), new MemoryCircuitBreaker(this.jvmService)); logger.info("Registered memory breaker."); diff --git a/src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java b/src/main/java/org/opensearch/timeseries/breaker/MemoryCircuitBreaker.java similarity index 95% rename from src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java rename to src/main/java/org/opensearch/timeseries/breaker/MemoryCircuitBreaker.java index c4628c639..cf4b47d71 100644 --- a/src/main/java/org/opensearch/ad/breaker/MemoryCircuitBreaker.java +++ b/src/main/java/org/opensearch/timeseries/breaker/MemoryCircuitBreaker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.breaker; +package org.opensearch.timeseries.breaker; import org.opensearch.monitor.jvm.JvmService; diff --git a/src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java b/src/main/java/org/opensearch/timeseries/breaker/ThresholdCircuitBreaker.java similarity index 94% rename from src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java rename to src/main/java/org/opensearch/timeseries/breaker/ThresholdCircuitBreaker.java index 30959b0c4..5d69ce1f9 100644 --- a/src/main/java/org/opensearch/ad/breaker/ThresholdCircuitBreaker.java +++ b/src/main/java/org/opensearch/timeseries/breaker/ThresholdCircuitBreaker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.breaker; +package org.opensearch.timeseries.breaker; /** * An abstract class for all breakers with threshold. diff --git a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java similarity index 73% rename from src/main/java/org/opensearch/ad/caching/CacheBuffer.java rename to src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java index d9ec0143d..c1130b831 100644 --- a/src/main/java/org/opensearch/ad/caching/CacheBuffer.java +++ b/src/main/java/org/opensearch/timeseries/caching/CacheBuffer.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; @@ -25,273 +25,149 @@ import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.ExpiringState; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.MemoryTracker.Origin; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.ratelimit.RequestPriority; -import org.opensearch.ad.util.DateUtils; - -/** - * We use a layered cache to manage active entities’ states. We have a two-level - * cache that stores active entity states in each node. Each detector has its - * dedicated cache that stores ten (dynamically adjustable) entities’ states per - * node. A detector’s hottest entities load their states in the dedicated cache. - * If less than 10 entities use the dedicated cache, the secondary cache can use - * the rest of the free memory available to AD. The secondary cache is a shared - * memory among all detectors for the long tail. The shared cache size is 10% - * heap minus all of the dedicated cache consumed by single-entity and multi-entity - * detectors. The shared cache’s size shrinks as the dedicated cache is filled - * up or more detectors are started. - * - * Implementation-wise, both dedicated cache and shared cache are stored in items - * and minimumCapacity controls the boundary. If items size is equals to or less - * than minimumCapacity, consider items as dedicated cache; otherwise, consider - * top minimumCapacity active entities (last X entities in priorityList) as in dedicated - * cache and all others in shared cache. 
- */ -public class CacheBuffer implements ExpiringState { +import org.opensearch.timeseries.ExpiringState; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainRequest; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.util.DateUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CacheBuffer & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker> + implements + ExpiringState { + private static final Logger LOG = LogManager.getLogger(CacheBuffer.class); - // max entities to track per detector - private final int MAX_TRACKING_ENTITIES = 1000000; + protected Instant lastUsedTime; + protected final Clock clock; + protected final MemoryTracker memoryTracker; + protected int checkpointIntervalHrs; + protected final Duration modelTtl; + + // max entities to track per detector + protected final int MAX_TRACKING_ENTITIES = 1000000; // the reserved cache size. So no matter how many entities there are, we will // keep the size for minimum capacity entities - private int minimumCapacity; - - // key is model id - private final ConcurrentHashMap> items; + protected int minimumCapacity; // memory consumption per entity - private final long memoryConsumptionPerEntity; - private final MemoryTracker memoryTracker; - private final Duration modelTtl; - private final String detectorId; - private Instant lastUsedTime; - private long reservedBytes; - private final PriorityTracker priorityTracker; - private final Clock clock; - private final CheckpointWriteWorker checkpointWriteQueue; - private final CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected final long memoryConsumptionPerModel; + protected long reservedBytes; + protected final CheckpointWriterType checkpointWriteQueue; + protected final CheckpointMaintainerType checkpointMaintainQueue; + protected final String configId; + protected final Origin origin; + protected final PriorityTracker priorityTracker; + // key is model id + protected final ConcurrentHashMap> items; public CacheBuffer( int minimumCapacity, - long intervalSecs, - long memoryConsumptionPerEntity, - MemoryTracker memoryTracker, Clock clock, + MemoryTracker memoryTracker, + int checkpointIntervalHrs, Duration modelTtl, - String detectorId, - CheckpointWriteWorker checkpointWriteQueue, - CheckpointMaintainWorker checkpointMaintainQueue, - int checkpointIntervalHrs + long memoryConsumptionPerEntity, + CheckpointWriterType checkpointWriteQueue, + CheckpointMaintainerType checkpointMaintainQueue, + String configId, + long intervalSecs, + Origin origin ) { - this.memoryConsumptionPerEntity = memoryConsumptionPerEntity; - setMinimumCapacity(minimumCapacity); - - this.items = new ConcurrentHashMap<>(); - this.memoryTracker = memoryTracker; - - this.modelTtl = modelTtl; - this.detectorId = detectorId; this.lastUsedTime = 
clock.instant(); - this.clock = clock; - this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.memoryTracker = memoryTracker; + setCheckpointIntervalHrs(checkpointIntervalHrs); + this.modelTtl = modelTtl; + setMinimumCapacity(minimumCapacity); + this.memoryConsumptionPerModel = memoryConsumptionPerEntity; this.checkpointWriteQueue = checkpointWriteQueue; this.checkpointMaintainQueue = checkpointMaintainQueue; - setCheckpointIntervalHrs(checkpointIntervalHrs); - } - - /** - * Update step at period t_k: - * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, - * and n is the period. - * @param entityModelId model Id - */ - private void update(String entityModelId) { - priorityTracker.updatePriority(entityModelId); - - Instant now = clock.instant(); - items.get(entityModelId).setLastUsedTime(now); - lastUsedTime = now; - } - - /** - * Insert the model state associated with a model Id to the cache - * @param entityModelId the model Id - * @param value the ModelState - */ - public void put(String entityModelId, ModelState value) { - // race conditions can happen between the put and one of the following operations: - // remove: not a problem as it is unlikely we are removing and putting the same thing - // maintenance: not a problem as we are unlikely to maintain an entry that's not - // already in the cache - // clear: not a problem as we are releasing memory in MemoryTracker. - // The newly added one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. - // put from other threads: not a problem as the entry is associated with - // entityModelId and our put is idempotent - put(entityModelId, value, value.getPriority()); + this.configId = configId; + this.origin = origin; + this.priorityTracker = new PriorityTracker(clock, intervalSecs, clock.instant().getEpochSecond(), MAX_TRACKING_ENTITIES); + this.items = new ConcurrentHashMap<>(); } - /** - * Insert the model state associated with a model Id to the cache. Update priority. - * @param entityModelId the model Id - * @param value the ModelState - * @param priority the priority - */ - private void put(String entityModelId, ModelState value, float priority) { - ModelState contentNode = items.get(entityModelId); - if (contentNode == null) { - priorityTracker.addPriority(entityModelId, priority); - items.put(entityModelId, value); - Instant now = clock.instant(); - value.setLastUsedTime(now); - lastUsedTime = now; - // shared cache empty means we are consuming reserved cache. - // Since we have already considered them while allocating CacheBuffer, - // skip bookkeeping. 
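The priority update quoted in the removed javadoc is a guarded log-sum-exp; it survives unchanged in `PriorityTracker.getUpdatedPriority` later in this diff. A minimal sketch, using the javadoc's `g(n) = e^{0.125n}` so that `log g(n) = 0.125 * n`:

```java
// new priority = old + log(1 + e^{log(g(t_k - L)) - old}), where L is the landmark
// time and n = t_k - L is the number of periods since then. Subtracting the old
// priority inside the exponent keeps recently used entities comparable without
// overflowing the float as n grows.
static float updatedPriority(float oldPriority, long periodsSinceLandmark) {
    float increment = 0.125f * periodsSinceLandmark;
    return oldPriority + (float) Math.log(1 + Math.exp(increment - oldPriority));
}
```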
- if (!sharedCacheEmpty()) { - memoryTracker.consumeMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); - } - } else { - update(entityModelId); - items.put(entityModelId, value); + public void setMinimumCapacity(int minimumCapacity) { + if (minimumCapacity < 0) { + throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); } + this.minimumCapacity = minimumCapacity; + this.reservedBytes = memoryConsumptionPerModel * minimumCapacity; } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState get(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; - } - update(key); - return node; + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); } - /** - * Retrieve the ModelState associated with the model Id or null if the CacheBuffer - * contains no mapping for the model Id. Compared to get method, the method won't - * increment entity priority. Used in cache buffer maintenance. - * - * @param key the model Id - * @return the Model state to which the specified model Id is mapped, or null - * if this CacheBuffer contains no mapping for the model Id - */ - public ModelState getWithoutUpdatePriority(String key) { - // We can get an item that is to be removed soon due to race condition. - // This is acceptable as it won't cause any corruption and exception. - // And this item is used for scoring one last time. - ModelState node = items.get(key); - if (node == null) { - return null; + public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { + this.checkpointIntervalHrs = checkpointIntervalHrs; + // 0 can cause java.lang.ArithmeticException: / by zero + // negative value is meaningless + if (checkpointIntervalHrs <= 0) { + this.checkpointIntervalHrs = 1; } - return node; } - /** - * - * @return whether there is one item that can be removed from shared cache - */ - public boolean canRemove() { - return !items.isEmpty() && items.size() > minimumCapacity; + public int getCheckpointIntervalHrs() { + return checkpointIntervalHrs; } /** - * remove the smallest priority item. - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove() { - // race conditions can happen between the put and one of the following operations: - // remove from other threads: not a problem. If they remove the same item, - // our method is idempotent. If they remove two different items, - // they don't impact each other. - // maintenance: not a problem as all of the data structures are concurrent. - // Two threads removing the same entry is not a problem. - // clear: not a problem as we are releasing memory in MemoryTracker. - // The removed one loses references and soon GC will collect it. - // We have memory tracking correction to fix incorrect memory usage record. 
- // put: not a problem as it is unlikely we are removing and putting the same thing - Optional key = priorityTracker.getMinimumPriorityEntityId(); - if (key.isPresent()) { - return remove(key.get()); - } - return null; + * + * @return reserved bytes by the CacheBuffer + */ + public long getReservedBytes() { + return reservedBytes; } /** - * Remove everything associated with the key and make a checkpoint. - * - * @param keyToRemove The key to remove - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove) { - return remove(keyToRemove, true); + * + * @return the estimated number of bytes per entity state + */ + public long getMemoryConsumptionPerModel() { + return memoryConsumptionPerModel; } - /** - * Remove everything associated with the key and make a checkpoint if input specified so. - * - * @param keyToRemove The key to remove - * @param saveCheckpoint Whether saving checkpoint or not - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key - */ - public ModelState remove(String keyToRemove, boolean saveCheckpoint) { - priorityTracker.removePriority(keyToRemove); - - // if shared cache is empty, we are using reserved memory - boolean reserved = sharedCacheEmpty(); - - ModelState valueRemoved = items.remove(keyToRemove); + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; - if (valueRemoved != null) { - if (!reserved) { - // release in shared memory - memoryTracker.releaseMemory(memoryConsumptionPerEntity, false, Origin.HC_DETECTOR); - } + if (obj instanceof CacheBuffer) { + @SuppressWarnings("unchecked") + CacheBuffer other = + (CacheBuffer) obj; - EntityModel modelRemoved = valueRemoved.getModel(); - if (modelRemoved != null) { - if (saveCheckpoint) { - // null model has only samples. For null model we save a checkpoint - // regardless of last checkpoint time. whether If we don't save, - // we throw the new samples and might never be able to initialize the model - boolean isNullModel = !modelRemoved.getTrcf().isPresent(); - checkpointWriteQueue.write(valueRemoved, isNullModel, RequestPriority.MEDIUM); - } + EqualsBuilder equalsBuilder = new EqualsBuilder(); + equalsBuilder.append(configId, other.configId); - modelRemoved.clear(); - } + return equalsBuilder.isEquals(); } + return false; + } - return valueRemoved; + @Override + public int hashCode() { + return new HashCodeBuilder().append(configId).toHashCode(); } - /** - * @return whether dedicated cache is available or not - */ - public boolean dedicatedCacheAvailable() { - return items.size() < minimumCapacity; + public String getConfigId() { + return configId; } /** @@ -302,56 +178,47 @@ public boolean sharedCacheEmpty() { } /** - * - * @return the estimated number of bytes per entity state - */ - public long getMemoryConsumptionPerEntity() { - return memoryConsumptionPerEntity; - } - - /** - * - * If the cache is not full, check if some other items can replace internal entities - * within the same detector. 
- * - * @param priority another entity's priority - * @return whether one entity can be replaced by another entity with a certain priority - */ - public boolean canReplaceWithinDetector(float priority) { - if (items.isEmpty()) { - return false; + * + * @return bytes consumed in the shared cache by the CacheBuffer + */ + public long getBytesInSharedCache() { + int sharedCacheEntries = items.size() - minimumCapacity; + if (sharedCacheEntries > 0) { + return memoryConsumptionPerModel * sharedCacheEntries; } - Optional> minPriorityItem = priorityTracker.getMinimumPriority(); - return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); + return 0; } /** - * Replace the smallest priority entity with the input entity - * @param entityModelId the Model Id - * @param value the model State - * @return the associated ModelState associated with the key, or null if there - * is no associated ModelState for the key + * Clear associated memory. Used when we are removing an detector. */ - public ModelState replace(String entityModelId, ModelState value) { - ModelState replaced = remove(); - put(entityModelId, value); - return replaced; + public void clear() { + // race conditions can happen between the put and remove/maintenance/put: + // not a problem as we are releasing memory in MemoryTracker. + // The newly added one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record. + memoryTracker.releaseMemory(getReservedBytes(), true, origin); + if (!sharedCacheEmpty()) { + memoryTracker.releaseMemory(getBytesInSharedCache(), false, origin); + } + items.clear(); + priorityTracker.clearPriority(); } /** * Remove expired state and save checkpoints of existing states * @return removed states */ - public List> maintenance() { + public List> maintenance() { List modelsToSave = new ArrayList<>(); - List> removedStates = new ArrayList<>(); + List> removedStates = new ArrayList<>(); Instant now = clock.instant(); int currentHour = DateUtils.getUTCHourOfDay(now); int currentSlot = currentHour % checkpointIntervalHrs; items.entrySet().stream().forEach(entry -> { String entityModelId = entry.getKey(); try { - ModelState modelState = entry.getValue(); + ModelState modelState = entry.getValue(); if (modelState.getLastUsedTime().plus(modelTtl).isBefore(now)) { // race conditions can happen between the put and one of the following operations: @@ -397,7 +264,7 @@ public List> maintenance() { new CheckpointMaintainRequest( // the request expires when the next maintainance starts System.currentTimeMillis() + modelTtl.toMillis(), - detectorId, + configId, RequestPriority.LOW, entityModelId ) @@ -414,9 +281,97 @@ public List> maintenance() { } /** + * Remove everything associated with the key and make a checkpoint if input specified so. 
+ * + * @param keyToRemove The key to remove + * @param saveCheckpoint Whether to save a checkpoint or not + * @return the ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState<RCFModelType> remove(String keyToRemove, boolean saveCheckpoint) { + priorityTracker.removePriority(keyToRemove); + + // if shared cache is empty, we are using reserved memory + boolean reserved = sharedCacheEmpty(); + + ModelState<RCFModelType> valueRemoved = items.remove(keyToRemove); + + if (valueRemoved != null) { + if (!reserved) { + // release in shared memory + memoryTracker.releaseMemory(memoryConsumptionPerModel, false, origin); + } + + if (saveCheckpoint) { + // null model has only samples. For null model we save a checkpoint + // regardless of last checkpoint time. If we don't save, + // we throw away the new samples and might never be able to initialize the model + checkpointWriteQueue.write(valueRemoved, valueRemoved.getModel().isEmpty(), RequestPriority.MEDIUM); + } + + valueRemoved.clear(); + } + + return valueRemoved; + } + + /** + * Remove everything associated with the key and make a checkpoint. * - * @return the number of active entities + * @param keyToRemove The key to remove + * @return the ModelState associated with the key, or null if there + * is no associated ModelState for the key */ + public ModelState<RCFModelType> remove(String keyToRemove) { + return remove(keyToRemove, true); + } + + public PriorityTracker getPriorityTracker() { + return priorityTracker; + } + + /** + * remove the smallest priority item. + * @return the ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState<RCFModelType> remove() { + // race conditions can happen between the put and one of the following operations: + // remove from other threads: not a problem. If they remove the same item, + // our method is idempotent. If they remove two different items, + // they don't impact each other. + // maintenance: not a problem as all of the data structures are concurrent. + // Two threads removing the same entry is not a problem. + // clear: not a problem as we are releasing memory in MemoryTracker. + // The removed one loses references and soon GC will collect it. + // We have memory tracking correction to fix incorrect memory usage record.
+ // put: not a problem as it is unlikely we are removing and putting the same thing + Optional<String> key = priorityTracker.getMinimumPriorityEntityId(); + if (key.isPresent()) { + return remove(key.get()); + } + return null; + } + + /** + * + * @return whether there is one item that can be removed from shared cache + */ + public boolean canRemove() { + return !items.isEmpty() && items.size() > minimumCapacity; + } + + /** + * @return whether dedicated cache is available or not + */ + public boolean dedicatedCacheAvailable() { + return items.size() < minimumCapacity; + } + + /** + * + * @return the number of active entities + */ public int getActiveEntities() { return items.size(); } @@ -436,7 +391,7 @@ public boolean isActive(String entityModelId) { * @return Last used time of the model */ public long getLastUsedTime(String entityModelId) { - ModelState<EntityModel> state = items.get(entityModelId); + ModelState<RCFModelType> state = items.get(entityModelId); if (state != null) { return state.getLastUsedTime().toEpochMilli(); } @@ -448,105 +403,139 @@ public long getLastUsedTime(String entityModelId) { * @param entityModelId entity Id * @return Get the model of an entity */ - public Optional<EntityModel> getModel(String entityModelId) { - return Optional.of(items).map(map -> map.get(entityModelId)).map(state -> state.getModel()); + public ModelState<RCFModelType> getModelState(String entityModelId) { + return items.get(entityModelId); } /** - * Clear associated memory. Used when we are removing an detector. + * Update step at period t_k: + * new priority = old priority + log(1+e^{\log(g(t_k-L))-old priority}) where g(n) = e^{0.125n}, + * and n is the period. + * @param entityModelId model Id */ - public void clear() { - // race conditions can happen between the put and remove/maintenance/put: - // not a problem as we are releasing memory in MemoryTracker. + private void update(String entityModelId) { + priorityTracker.updatePriority(entityModelId); + + Instant now = clock.instant(); + items.get(entityModelId).setLastUsedTime(now); + lastUsedTime = now; + } + + /** + * Insert the model state associated with a model Id to the cache + * @param entityModelId the model Id + * @param value the ModelState + */ + public void put(String entityModelId, ModelState<RCFModelType> value) { + // race conditions can happen between the put and one of the following operations: + // remove: not a problem as it is unlikely we are removing and putting the same thing + // maintenance: not a problem as we are unlikely to maintain an entry that's not + // already in the cache + // clear: not a problem as we are releasing memory in MemoryTracker. // The newly added one loses references and soon GC will collect it. // We have memory tracking correction to fix incorrect memory usage record. - memoryTracker.releaseMemory(getReservedBytes(), true, Origin.HC_DETECTOR); - if (!sharedCacheEmpty()) { - memoryTracker.releaseMemory(getBytesInSharedCache(), false, Origin.HC_DETECTOR); + // put from other threads: not a problem as the entry is associated with + // entityModelId and our put is idempotent + put(entityModelId, value, value.getPriority()); + } + + /** + * Insert the model state associated with a model Id to the cache. Update priority.
+ * @param entityModelId the model Id + * @param value the ModelState + * @param priority the priority + */ + private void put(String entityModelId, ModelState value, float priority) { + ModelState contentNode = items.get(entityModelId); + if (contentNode == null) { + priorityTracker.addPriority(entityModelId, priority); + items.put(entityModelId, value); + Instant now = clock.instant(); + value.setLastUsedTime(now); + lastUsedTime = now; + // shared cache empty means we are consuming reserved cache. + // Since we have already considered them while allocating CacheBuffer, + // skip bookkeeping. + if (!sharedCacheEmpty()) { + memoryTracker.consumeMemory(memoryConsumptionPerModel, false, origin); + } + } else { + update(entityModelId); + items.put(entityModelId, value); } - items.clear(); - priorityTracker.clearPriority(); } /** - * - * @return reserved bytes by the CacheBuffer + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getReservedBytes() { - return reservedBytes; + public ModelState get(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; + } + update(key); + return node; } /** + * Retrieve the ModelState associated with the model Id or null if the CacheBuffer + * contains no mapping for the model Id. Compared to get method, the method won't + * increment entity priority. Used in cache buffer maintenance. * - * @return bytes consumed in the shared cache by the CacheBuffer + * @param key the model Id + * @return the Model state to which the specified model Id is mapped, or null + * if this CacheBuffer contains no mapping for the model Id */ - public long getBytesInSharedCache() { - int sharedCacheEntries = items.size() - minimumCapacity; - if (sharedCacheEntries > 0) { - return memoryConsumptionPerEntity * sharedCacheEntries; + public ModelState getWithoutUpdatePriority(String key) { + // We can get an item that is to be removed soon due to race condition. + // This is acceptable as it won't cause any corruption and exception. + // And this item is used for scoring one last time. + ModelState node = items.get(key); + if (node == null) { + return null; } - return 0; + return node; } - @Override - public boolean equals(Object obj) { - if (this == obj) - return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) + /** + * + * If the cache is not full, check if some other items can replace internal entities + * within the same detector. 
+ * + * @param priority another entity's priority + * @return whether one entity can be replaced by another entity with a certain priority + */ + public boolean canReplaceWithinDetector(float priority) { + if (items.isEmpty()) { return false; - if (obj instanceof InitProgressProfile) { - CacheBuffer other = (CacheBuffer) obj; - - EqualsBuilder equalsBuilder = new EqualsBuilder(); - equalsBuilder.append(detectorId, other.detectorId); - - return equalsBuilder.isEquals(); } - return false; - } - - @Override - public int hashCode() { - return new HashCodeBuilder().append(detectorId).toHashCode(); - } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); + Optional> minPriorityItem = priorityTracker.getMinimumPriority(); + return minPriorityItem.isPresent() && priority > minPriorityItem.get().getValue(); } - public String getId() { - return detectorId; + /** + * Replace the smallest priority entity with the input entity + * @param entityModelId the Model Id + * @param value the model State + * @return the associated ModelState associated with the key, or null if there + * is no associated ModelState for the key + */ + public ModelState replace(String entityModelId, ModelState value) { + ModelState replaced = remove(); + put(entityModelId, value); + return replaced; } - public List> getAllModels() { + public List> getAllModelStates() { return items.values().stream().collect(Collectors.toList()); } - - public PriorityTracker getPriorityTracker() { - return priorityTracker; - } - - public void setMinimumCapacity(int minimumCapacity) { - if (minimumCapacity < 0) { - throw new IllegalArgumentException("minimum capacity should be larger than or equal 0"); - } - this.minimumCapacity = minimumCapacity; - this.reservedBytes = memoryConsumptionPerEntity * minimumCapacity; - } - - public void setCheckpointIntervalHrs(int checkpointIntervalHrs) { - this.checkpointIntervalHrs = checkpointIntervalHrs; - // 0 can cause java.lang.ArithmeticException: / by zero - // negative value is meaningless - if (checkpointIntervalHrs <= 0) { - this.checkpointIntervalHrs = 1; - } - } - - public int getCheckpointIntervalHrs() { - return checkpointIntervalHrs; - } } diff --git a/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java new file mode 100644 index 000000000..9b4a53705 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/CacheProvider.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import org.opensearch.common.inject.Provider; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A wrapper to call concrete implementation of caching. Used in transport + * action. Don't use interface because transport action handler constructor + * requires a concrete class as input. 
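The javadoc above explains why a concrete `Provider` wrapper is needed; roughly, the wiring looks like this (a hypothetical sketch: the factory call is assumed, and the generics follow the class declaration below):

```java
// Hypothetical wiring sketch. The provider is created first so transport action
// constructors can accept it; the concrete cache is set once it has been built.
CacheProvider<ThresholdedRandomCutForest, TimeSeriesCache<ThresholdedRandomCutForest>> cacheProvider =
    new CacheProvider<>();
// ... pass cacheProvider into transport action constructors here ...
TimeSeriesCache<ThresholdedRandomCutForest> cache = buildPriorityCache(); // assumed factory
cacheProvider.set(cache);
// at request time, handlers resolve the cache lazily:
cacheProvider.get().maintenance();
```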
+ * + */ +public class CacheProvider<RCFModelType extends ThresholdedRandomCutForest, CacheType extends TimeSeriesCache<RCFModelType>> + implements + Provider<CacheType> { + private CacheType cache; + + public CacheProvider() { + + } + + @Override + public CacheType get() { + return cache; + } + + public void set(CacheType cache) { + this.cache = cache; + } +} diff --git a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java similarity index 94% rename from src/main/java/org/opensearch/ad/caching/DoorKeeper.java rename to src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java index 96a18d8f6..58af131d1 100644 --- a/src/main/java/org/opensearch/ad/caching/DoorKeeper.java +++ b/src/main/java/org/opensearch/timeseries/caching/DoorKeeper.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; @@ -17,8 +17,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.ExpiringState; -import org.opensearch.ad.MaintenanceState; +import org.opensearch.timeseries.ExpiringState; +import org.opensearch.timeseries.MaintenanceState; import com.google.common.base.Charsets; import com.google.common.hash.BloomFilter; diff --git a/src/main/java/org/opensearch/ad/caching/PriorityCache.java b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java similarity index 69% rename from src/main/java/org/opensearch/ad/caching/PriorityCache.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityCache.java index be8c05397..c7688b6ce 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityCache.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityCache.java @@ -9,16 +9,14 @@ * GitHub history for details.
*/ -package org.opensearch.ad.caching; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEDICATED_CACHE_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -32,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Triple; @@ -39,65 +38,67 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.ActionListener; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.MemoryTracker.Origin; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.ModelProfile; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.settings.ADEnabledSetting; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.ratelimit.CheckpointMaintainWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.DateUtils; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; -public class PriorityCache implements EntityCache { - private final Logger LOG = LogManager.getLogger(PriorityCache.class); +public abstract class PriorityCache & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriterType extends CheckpointWriteWorker, CheckpointMaintainerType extends CheckpointMaintainWorker, CacheBufferType extends CacheBuffer> + implements + TimeSeriesCache { + + private static final Logger LOG = LogManager.getLogger(PriorityCache.class); // detector id -> CacheBuffer, weight based - private final 
Map activeEnities; - private final CheckpointDao checkpointDao; - private volatile int dedicatedCacheSize; + private final Map activeEnities; + private final CheckpointDaoType checkpointDao; + protected volatile int hcDedicatedCacheSize; // LRU Cache, key is model id - private Cache> inActiveEntities; - private final MemoryTracker memoryTracker; + private Cache> inActiveEntities; + protected final MemoryTracker memoryTracker; private final ReentrantLock maintenanceLock; private final int numberOfTrees; - private final Clock clock; - private final Duration modelTtl; + protected final Clock clock; + protected final Duration modelTtl; // A bloom filter placed in front of inactive entity cache to // filter out unpopular items that are not likely to appear more // than once. Key is detector id private Map doorKeepers; private ThreadPool threadPool; + private String threadPoolName; private Random random; - private CheckpointWriteWorker checkpointWriteQueue; // iterating through all of inactive entities is heavy. We don't want to do // it again and again for no obvious benefits. private Instant lastInActiveEntityMaintenance; protected int maintenanceFreqConstant; - private CheckpointMaintainWorker checkpointMaintainQueue; - private int checkpointIntervalHrs; + protected int checkpointIntervalHrs; + private Origin origin; public PriorityCache( - CheckpointDao checkpointDao, - int dedicatedCacheSize, + CheckpointDaoType checkpointDao, + int hcDedicatedCacheSize, Setting checkpointTtl, int maxInactiveStates, MemoryTracker memoryTracker, @@ -106,22 +107,24 @@ public PriorityCache( ClusterService clusterService, Duration modelTtl, ThreadPool threadPool, - CheckpointWriteWorker checkpointWriteQueue, + String threadPoolName, int maintenanceFreqConstant, - CheckpointMaintainWorker checkpointMaintainQueue, Settings settings, - Setting checkpointSavingFreq + Setting checkpointSavingFreq, + Origin origin, + Setting dedicatedCacheSizeSetting, + Setting modelMaxSizePercent ) { this.checkpointDao = checkpointDao; this.activeEnities = new ConcurrentHashMap<>(); - this.dedicatedCacheSize = dedicatedCacheSize; - clusterService.getClusterSettings().addSettingsUpdateConsumer(DEDICATED_CACHE_SIZE, (it) -> { - this.dedicatedCacheSize = it; - this.setDedicatedCacheSizeListener(); + this.hcDedicatedCacheSize = hcDedicatedCacheSize; + clusterService.getClusterSettings().addSettingsUpdateConsumer(dedicatedCacheSizeSetting, (it) -> { + this.hcDedicatedCacheSize = it; + this.setHCDedicatedCacheSizeListener(); this.tryClearUpMemory(); }, this::validateDedicatedCacheSize); - clusterService.getClusterSettings().addSettingsUpdateConsumer(MODEL_MAX_SIZE_PERCENTAGE, it -> this.tryClearUpMemory()); + clusterService.getClusterSettings().addSettingsUpdateConsumer(modelMaxSizePercent, it -> this.tryClearUpMemory()); this.memoryTracker = memoryTracker; this.maintenanceLock = new ReentrantLock(); @@ -141,37 +144,37 @@ public PriorityCache( ); this.threadPool = threadPool; + this.threadPoolName = threadPoolName; this.random = new Random(42); - this.checkpointWriteQueue = checkpointWriteQueue; this.lastInActiveEntityMaintenance = Instant.MIN; this.maintenanceFreqConstant = maintenanceFreqConstant; - this.checkpointMaintainQueue = checkpointMaintainQueue; this.checkpointIntervalHrs = DateUtils.toDuration(checkpointSavingFreq.get(settings)).toHoursPart(); clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointSavingFreq, it -> { this.checkpointIntervalHrs = DateUtils.toDuration(it).toHoursPart(); 
this.setCheckpointFreqListener(); }); + this.origin = origin; } @Override - public ModelState get(String modelId, AnomalyDetector detector) { - String detectorId = detector.getId(); - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); - ModelState modelState = buffer.get(modelId); + public ModelState get(String modelId, Config config) { + String configId = config.getId(); + CacheBufferType buffer = computeBufferIfAbsent(config, configId); + ModelState modelState = buffer.get(modelId); // during maintenance period, stop putting new entries if (!maintenanceLock.isLocked() && modelState == null) { - if (ADEnabledSetting.isDoorKeeperInCacheEnabled()) { + if (isDoorKeeperInCacheEnabled()) { DoorKeeper doorKeeper = doorKeepers .computeIfAbsent( - detectorId, + configId, id -> { // reset every 60 intervals return new DoorKeeper( TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION, TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, - detector.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), clock ); } @@ -183,19 +186,14 @@ public ModelState get(String modelId, AnomalyDetector detector) { // this model Id. We have to call isActive method to make sure. Otherwise, // the entity might miss an anomaly result every 60 intervals due to door keeper // reset. - if (!doorKeeper.mightContain(modelId) && !isActive(detectorId, modelId)) { + if (!doorKeeper.mightContain(modelId) && !isActive(configId, modelId)) { doorKeeper.put(modelId); return null; } } try { - ModelState state = inActiveEntities.get(modelId, new Callable>() { - @Override - public ModelState call() { - return new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); - } - }); + ModelState state = inActiveEntities.get(modelId, createInactiveEntityCacheLoader(modelId, configId)); // make sure no model has been stored due to previous race conditions state.setModel(null); @@ -236,7 +234,7 @@ public ModelState call() { return modelState; } - private Optional> getStateFromInactiveEntiiyCache(String modelId) { + private Optional> getStateFromInactiveEntiiyCache(String modelId) { if (modelId == null) { return Optional.empty(); } @@ -246,25 +244,25 @@ private Optional> getStateFromInactiveEntiiyCache(String } @Override - public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate) { - if (toUpdate == null) { + public boolean hostIfPossible(Config config, ModelState toUpdate) { + if (toUpdate == null || toUpdate.getModel() == null) { return false; } String modelId = toUpdate.getModelId(); - String detectorId = toUpdate.getId(); + String detectorId = toUpdate.getConfigId(); if (Strings.isEmpty(modelId) || Strings.isEmpty(detectorId)) { return false; } - CacheBuffer buffer = computeBufferIfAbsent(detector, detectorId); + CacheBufferType buffer = computeBufferIfAbsent(config, detectorId); - Optional> state = getStateFromInactiveEntiiyCache(modelId); + Optional> state = getStateFromInactiveEntiiyCache(modelId); if (false == state.isPresent()) { return false; } - ModelState modelState = state.get(); + ModelState modelState = state.get(); float priority = modelState.getPriority(); @@ -272,13 +270,13 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState toUpdate.setPriority(priority); // current buffer's dedicated cache has free slots or can allocate in shared cache - if (buffer.dedicatedCacheAvailable() || 
memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (buffer.dedicatedCacheAvailable() || memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); return true; } - if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // buffer.put will call MemoryTracker.consumeMemory buffer.put(modelId, toUpdate); return true; @@ -286,7 +284,7 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState // can replace an entity in the same CacheBuffer living in reserved or shared cache if (buffer.canReplaceWithinDetector(priority)) { - ModelState removed = buffer.replace(modelId, toUpdate); + ModelState removed = buffer.replace(modelId, toUpdate); // null in the case of some other threads have emptied the queue at // the same time so there is nothing to replace if (removed != null) { @@ -298,10 +296,10 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState // If two threads try to remove the same entity and add their own state, the 2nd remove // returns null and only the first one succeeds. float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); - CacheBuffer bufferToRemove = bufferToRemoveEntity.getLeft(); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + CacheBufferType bufferToRemove = bufferToRemoveEntity.getLeft(); String entityModelId = bufferToRemoveEntity.getMiddle(); - ModelState removed = null; + ModelState removed = null; if (bufferToRemove != null && ((removed = bufferToRemove.remove(entityModelId)) != null)) { buffer.put(modelId, toUpdate); addIntoInactiveCache(removed); @@ -311,7 +309,7 @@ public boolean hostIfPossible(AnomalyDetector detector, ModelState return false; } - private void addIntoInactiveCache(ModelState removed) { + private void addIntoInactiveCache(ModelState removed) { if (removed == null) { return; } @@ -321,10 +319,10 @@ private void addIntoInactiveCache(ModelState removed) { inActiveEntities.put(removed.getModelId(), removed); } - private void addEntity(List destination, Entity entity, String detectorId) { + private void addEntity(List destination, Entity entity, String configId) { // It's possible our doorkeepr prevented the entity from entering inactive entities cache if (entity != null) { - Optional modelId = entity.getModelId(detectorId); + Optional modelId = entity.getModelId(configId); if (modelId.isPresent() && inActiveEntities.getIfPresent(modelId.get()) != null) { destination.add(entity); } @@ -335,12 +333,12 @@ private void addEntity(List destination, Entity entity, String detectorI public Pair, List> selectUpdateCandidate( Collection cacheMissEntities, String detectorId, - AnomalyDetector detector + Config detector ) { List hotEntities = new ArrayList<>(); List coldEntities = new ArrayList<>(); - CacheBuffer buffer = activeEnities.get(detectorId); + CacheBufferType buffer = activeEnities.get(detectorId); if (buffer == null) { // don't want to create side-effects by creating a CacheBuffer // In current implementation, this branch is impossible as we call @@ -356,7 +354,7 @@ public Pair, List> selectUpdateCandidate( addEntity(hotEntities, cacheMissEntitiesIter.next(), detectorId); } - while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerEntity())) { + 
while (cacheMissEntitiesIter.hasNext() && memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) { // can allocate in shared cache // race conditions can happen when multiple threads evaluating this condition. // This is a problem as our AD memory usage is close to full and we put @@ -383,13 +381,13 @@ public Pair, List> selectUpdateCandidate( continue; } - Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> state = getStateFromInactiveEntiiyCache(modelId.get()); if (false == state.isPresent()) { // not even recorded in inActiveEntities yet because of doorKeeper continue; } - ModelState modelState = state.get(); + ModelState modelState = state.get(); float priority = modelState.getPriority(); if (buffer.canReplaceWithinDetector(priority)) { @@ -402,7 +400,7 @@ public Pair, List> selectUpdateCandidate( // record current minimum priority among all detectors to save redundant // scanning of all CacheBuffers - CacheBuffer bufferToRemove = null; + CacheBufferType bufferToRemove = null; float minPriority = Float.MIN_VALUE; // check if we can replace in other CacheBuffer @@ -418,13 +416,13 @@ public Pair, List> selectUpdateCandidate( continue; } - Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); + Optional> inactiveState = getStateFromInactiveEntiiyCache(modelId.get()); if (false == inactiveState.isPresent()) { // empty state should not stand a chance to replace others continue; } - ModelState state = inactiveState.get(); + ModelState state = inactiveState.get(); float priority = state.getPriority(); float scaledPriority = buffer.getPriorityTracker().getScaledPriority(priority); @@ -437,7 +435,7 @@ public Pair, List> selectUpdateCandidate( // Float.MIN_VALUE means we need to re-iterate through all CacheBuffers if (minPriority == Float.MIN_VALUE) { - Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); + Triple bufferToRemoveEntity = canReplaceInSharedCache(buffer, scaledPriority); bufferToRemove = bufferToRemoveEntity.getLeft(); minPriority = bufferToRemoveEntity.getRight(); } @@ -456,33 +454,20 @@ public Pair, List> selectUpdateCandidate( return Pair.of(hotEntities, coldEntities); } - private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detectorId) { - CacheBuffer buffer = activeEnities.get(detectorId); + private CacheBufferType computeBufferIfAbsent(Config config, String configId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { - long requiredBytes = getRequiredMemory(detector, dedicatedCacheSize); + long requiredBytes = getRequiredMemory(config, config.isHighCardinality() ? hcDedicatedCacheSize : 1); if (memoryTracker.canAllocateReserved(requiredBytes)) { - memoryTracker.consumeMemory(requiredBytes, true, Origin.HC_DETECTOR); - long intervalSecs = detector.getIntervalInSeconds(); - - buffer = new CacheBuffer( - dedicatedCacheSize, - intervalSecs, - getRequiredMemory(detector, 1), - memoryTracker, - clock, - modelTtl, - detectorId, - checkpointWriteQueue, - checkpointMaintainQueue, - checkpointIntervalHrs - ); - activeEnities.put(detectorId, buffer); + memoryTracker.consumeMemory(requiredBytes, true, origin); + buffer = createEmptyCacheBuffer(config, getRequiredMemory(config, 1)); + activeEnities.put(configId, buffer); // There can be race conditions between tryClearUpMemory and // activeEntities.put above as tryClearUpMemory accesses activeEnities too. // Put tryClearUpMemory after consumeMemory to prevent that. 
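In outline, the per-entity decision `selectUpdateCandidate` makes here boils down to three admission paths before an entity is declared cold (simplified from the surrounding diff; the real code also consults other detectors' buffers and re-checks race conditions):

```java
// Simplified per-entity branch of selectUpdateCandidate; variable names match
// the method above, control flow is condensed for illustration.
if (buffer.dedicatedCacheAvailable()) {
    hotEntities.add(entity);          // free dedicated slot
} else if (memoryTracker.canAllocate(buffer.getMemoryConsumptionPerModel())) {
    hotEntities.add(entity);          // room left in the shared cache
} else if (buffer.canReplaceWithinDetector(priority)) {
    hotEntities.add(entity);          // evict a lower-priority entity of the same detector
} else {
    coldEntities.add(entity);         // defer: update later at low priority
}
```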
tryClearUpMemory(); } else { - throw new LimitExceededException(detectorId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); + throw new LimitExceededException(configId, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG); } } @@ -491,20 +476,12 @@ private CacheBuffer computeBufferIfAbsent(AnomalyDetector detector, String detec /** * - * @param detector Detector config accessor + * @param config Detector config accessor * @param numberOfEntity number of entities * @return Memory in bytes required for hosting numberOfEntity entities */ - private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { - int dimension = detector.getEnabledFeatureIds().size() * detector.getShingleSize(); - return numberOfEntity * memoryTracker - .estimateTRCFModelSize( - dimension, - numberOfTrees, - TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, - detector.getShingleSize().intValue(), - true - ); + private long getRequiredMemory(Config config, int numberOfEntity) { + return numberOfEntity * getRequiredMemoryPerEntity(config, memoryTracker, numberOfTrees); } /** @@ -518,12 +495,12 @@ private long getRequiredMemory(AnomalyDetector detector, int numberOfEntity) { * @param candidatePriority the candidate entity's priority * @return the CacheBuffer if we can find a CacheBuffer to make room for the candidate entity */ - private Triple canReplaceInSharedCache(CacheBuffer originBuffer, float candidatePriority) { - CacheBuffer minPriorityBuffer = null; + private Triple canReplaceInSharedCache(CacheBufferType originBuffer, float candidatePriority) { + CacheBufferType minPriorityBuffer = null; float minPriority = candidatePriority; String minPriorityEntityModelId = null; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); if (buffer != originBuffer && buffer.canRemove()) { Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { @@ -548,7 +525,7 @@ private Triple canReplaceInSharedCache(CacheBuffer o private void tryClearUpMemory() { try { if (maintenanceLock.tryLock()) { - threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME).execute(() -> clearMemory()); + threadPool.executor(threadPoolName).execute(() -> clearMemory()); } else { threadPool.schedule(() -> { try { @@ -556,7 +533,7 @@ private void tryClearUpMemory() { } catch (Exception e) { LOG.error("Fail to clear up memory taken by CacheBuffer. 
Will retry during maintenance."); } - }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME); + }, new TimeValue(random.nextInt(90), TimeUnit.SECONDS), threadPoolName); } } finally { if (maintenanceLock.isHeldByCurrentThread()) { @@ -568,12 +545,12 @@ private void tryClearUpMemory() { private void clearMemory() { recalculateUsedMemory(); long memoryToShed = memoryTracker.memoryToShed(); - PriorityQueue> removalCandiates = null; + PriorityQueue> removalCandiates = null; if (memoryToShed > 0) { // sort the triple in an ascending order of priority removalCandiates = new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft())); - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); Optional> priorityEntry = buffer.getPriorityTracker().getMinimumScaledPriority(); if (!priorityEntry.isPresent()) { continue; @@ -586,12 +563,12 @@ private void clearMemory() { } while (memoryToShed > 0) { if (false == removalCandiates.isEmpty()) { - Triple toRemove = removalCandiates.poll(); - CacheBuffer minPriorityBuffer = toRemove.getMiddle(); + Triple toRemove = removalCandiates.poll(); + CacheBufferType minPriorityBuffer = toRemove.getMiddle(); String minPriorityEntityModelId = toRemove.getRight(); - ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); - memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerEntity(); + ModelState removed = minPriorityBuffer.remove(minPriorityEntityModelId); + memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerModel(); addIntoInactiveCache(removed); if (minPriorityBuffer.canRemove()) { @@ -616,12 +593,12 @@ private void clearMemory() { private void recalculateUsedMemory() { long reserved = 0; long shared = 0; - for (Map.Entry entry : activeEnities.entrySet()) { - CacheBuffer buffer = entry.getValue(); + for (Map.Entry entry : activeEnities.entrySet()) { + CacheBufferType buffer = entry.getValue(); reserved += buffer.getReservedBytes(); shared += buffer.getBytesInSharedCache(); } - memoryTracker.syncMemoryState(Origin.HC_DETECTOR, reserved + shared, reserved); + memoryTracker.syncMemoryState(origin, reserved + shared, reserved); } /** @@ -637,15 +614,15 @@ public void maintenance() { // clean up memory if we allocate more memory than we should tryClearUpMemory(); activeEnities.entrySet().stream().forEach(cacheBufferEntry -> { - String detectorId = cacheBufferEntry.getKey(); - CacheBuffer cacheBuffer = cacheBufferEntry.getValue(); + String configId = cacheBufferEntry.getKey(); + CacheBufferType cacheBuffer = cacheBufferEntry.getValue(); // remove expired cache buffer if (cacheBuffer.expired(modelTtl)) { - activeEnities.remove(detectorId); + activeEnities.remove(configId); cacheBuffer.clear(); } else { - List> removedStates = cacheBuffer.maintenance(); - for (ModelState state : removedStates) { + List> removedStates = cacheBuffer.maintenance(); + for (ModelState state : removedStates) { addIntoInactiveCache(state); } } @@ -654,11 +631,11 @@ public void maintenance() { maintainInactiveCache(); doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { - String detectorId = doorKeeperEntry.getKey(); + String configId = doorKeeperEntry.getKey(); DoorKeeper doorKeeper = doorKeeperEntry.getValue(); // doorKeeper has its own state ttl if (doorKeeper.expired(null)) { - doorKeepers.remove(detectorId); + doorKeepers.remove(configId); } else { doorKeeper.maintenance(); 
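`clearMemory` above is a shed-until-quota loop over a min-heap seeded with each buffer's lowest scaled priority; condensed (types simplified, seeding elided):

```java
// Condensed form of the eviction loop in clearMemory. Each heap entry is
// (min scaled priority, owning buffer, model id); evict cheapest-to-lose first.
PriorityQueue<Triple<Float, CacheBufferType, String>> removalCandidates =
    new PriorityQueue<>((x, y) -> Float.compare(x.getLeft(), y.getLeft()));
// ... seed removalCandidates from every active CacheBuffer ...
while (memoryToShed > 0 && !removalCandidates.isEmpty()) {
    Triple<Float, CacheBufferType, String> toRemove = removalCandidates.poll();
    CacheBufferType minPriorityBuffer = toRemove.getMiddle();
    addIntoInactiveCache(minPriorityBuffer.remove(toRemove.getRight()));
    memoryToShed -= minPriorityBuffer.getMemoryConsumptionPerModel();
    // if the buffer can still shed, its next minimum-priority entry is re-offered here
}
```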
} @@ -673,19 +650,19 @@ public void maintenance() { /** * Permanently deletes models hosted in memory and persisted in index. * - * @param detectorId id the of the detector for which models are to be permanently deleted + * @param configId the id of the config for which models are to be permanently deleted */ @Override - public void clear(String detectorId) { - if (Strings.isEmpty(detectorId)) { + public void clear(String configId) { + if (Strings.isEmpty(configId)) { return; } - CacheBuffer buffer = activeEnities.remove(detectorId); + CacheBufferType buffer = activeEnities.remove(configId); if (buffer != null) { buffer.clear(); } - checkpointDao.deleteModelCheckpointByDetectorId(detectorId); - doorKeepers.remove(detectorId); + checkpointDao.deleteModelCheckpointByConfigId(configId); + doorKeepers.remove(configId); } /** @@ -695,7 +672,7 @@ public void clear(String detectorId) { */ @Override public int getActiveEntities(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { return cacheBuffer.getActiveEntities(); } @@ -704,13 +681,13 @@ /** * Whether an entity is active or not - * @param detectorId The Id of the detector that an entity belongs to + * @param configId The Id of the detector that an entity belongs to * @param entityModelId Entity's Model Id * @return Whether an entity is active or not */ @Override - public boolean isActive(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public boolean isActive(String configId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); if (cacheBuffer != null) { return cacheBuffer.isActive(entityModelId); } @@ -718,32 +695,22 @@ } @Override - public long getTotalUpdates(String detectorId) { + public long getTotalUpdates(String configId) { return Optional .of(activeEnities) - .map(entities -> entities.get(detectorId)) + .map(entities -> entities.get(configId)) .map(buffer -> buffer.getPriorityTracker().getHighestPriorityEntityId()) .map(entityModelIdOptional -> entityModelIdOptional.get()) - .map(entityModelId -> getTotalUpdates(detectorId, entityModelId)) + .map(entityModelId -> getTotalUpdates(configId, entityModelId)) .orElse(0L); } @Override - public long getTotalUpdates(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null) { - Optional<EntityModel> modelOptional = cacheBuffer.getModel(entityModelId); - // TODO: make it work for shingles. samples.size() is not the real shingle - long accumulatedShingles = modelOptional - .flatMap(model -> model.getTrcf()) - .map(trcf -> trcf.getForest()) - .map(rcf -> rcf.getTotalUpdates()) - .orElseGet( - () -> modelOptional.map(model -> model.getSamples()).map(samples -> samples.size()).map(Long::valueOf).orElse(0L) - ); - return accumulatedShingles; - } - return 0L; + public long getTotalUpdates(String configId, String entityModelId) { + return Optional + .ofNullable(activeEnities.get(configId)) + .map(cacheBuffer -> getTotalUpdates(cacheBuffer.getModelState(entityModelId))) + .orElse(0L); } /** @@ -763,24 +730,25 @@ public int getTotalActiveEntities() { * @return list of modelStates */ @Override - public List<ModelState<EntityModel>> getAllModels() { - List<ModelState<EntityModel>> states = new ArrayList<>(); - activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModels())); + public List<ModelState<RCFModelType>> getAllModels() { + List<ModelState<RCFModelType>> states = new ArrayList<>(); + activeEnities.values().stream().forEach(cacheBuffer -> states.addAll(cacheBuffer.getAllModelStates())); return states; } /** * Gets all of a detector's model sizes hosted on a node * + * @param configId config Id * @return a map of model id to its memory size */ @Override - public Map<String, Long> getModelSize(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public Map<String, Long> getModelSize(String configId) { + CacheBufferType cacheBuffer = activeEnities.get(configId); Map<String, Long> res = new HashMap<>(); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> res.put(entry.getModelId(), size)); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + cacheBuffer.getAllModelStates().forEach(entry -> res.put(entry.getModelId(), size)); } return res; } @@ -799,8 +767,8 @@ public Map<String, Long> getModelSize(String detectorId) { * milliseconds when the entity's state is lastly used. Otherwise, return -1. */ @Override - public long getLastActiveMs(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); + public long getLastActiveTime(String detectorId, String entityModelId) { + CacheBufferType cacheBuffer = activeEnities.get(detectorId); long lastUsedMs = -1; if (cacheBuffer != null) { lastUsedMs = cacheBuffer.getLastUsedTime(entityModelId); @@ -808,7 +776,7 @@ public long getLastActiveMs(String detectorId, String entityModelId) { return lastUsedMs; } } - ModelState<EntityModel> stateInActive = inActiveEntities.getIfPresent(entityModelId); + ModelState<RCFModelType> stateInActive = inActiveEntities.getIfPresent(entityModelId); if (stateInActive != null) { lastUsedMs = stateInActive.getLastUsedTime().toEpochMilli(); } @@ -822,7 +790,7 @@ public void releaseMemoryForOpenCircuitBreaker() { tryClearUpMemory(); activeEnities.values().stream().forEach(cacheBuffer -> { if (cacheBuffer.canRemove()) { - ModelState<EntityModel> removed = cacheBuffer.remove(); + ModelState<RCFModelType> removed = cacheBuffer.remove(); addIntoInactiveCache(removed); } }); @@ -838,9 +806,9 @@ private void maintainInactiveCache() { inActiveEntities.cleanUp(); // make sure no model has been stored due to bugs - for (ModelState<EntityModel> state : inActiveEntities.asMap().values()) { - EntityModel model = state.getModel(); - if (model != null && model.getTrcf().isPresent()) { + for (ModelState<RCFModelType> state : inActiveEntities.asMap().values()) { + Optional<RCFModelType> modelOptional = state.getModel(); + if (modelOptional.isPresent()) { LOG.warn(new ParameterizedMessage("Inactive entity's model is not null: [{}].
Maybe there are bugs.", state.getModelId())); state.setModel(null); } @@ -853,8 +821,8 @@ private void maintainInactiveCache() { * Called when dedicated cache size changes. Will adjust existing cache buffer's * cache size */ - private void setDedicatedCacheSizeListener() { - activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(dedicatedCacheSize)); + private void setHCDedicatedCacheSizeListener() { + activeEnities.values().stream().forEach(cacheBuffer -> cacheBuffer.setMinimumCapacity(hcDedicatedCacheSize)); } private void setCheckpointFreqListener() { @@ -863,20 +831,16 @@ private void setCheckpointFreqListener() { @Override public List getAllModelProfile(String detectorId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - List res = new ArrayList<>(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); if (cacheBuffer != null) { - long size = cacheBuffer.getMemoryConsumptionPerEntity(); - cacheBuffer.getAllModels().forEach(entry -> { - EntityModel model = entry.getModel(); - Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); - } - res.add(new ModelProfile(entry.getModelId(), entity, size)); - }); + long size = cacheBuffer.getMemoryConsumptionPerModel(); + return cacheBuffer + .getAllModelStates() + .stream() + .map(entry -> new ModelProfile(entry.getModelId(), entry.getEntity().orElse(null), size)) + .collect(Collectors.toList()); } - return res; + return Collections.emptyList(); } /** @@ -888,14 +852,14 @@ public List getAllModelProfile(String detectorId) { */ @Override public Optional getModelProfile(String detectorId, String entityModelId) { - CacheBuffer cacheBuffer = activeEnities.get(detectorId); - if (cacheBuffer != null && cacheBuffer.getModel(entityModelId).isPresent()) { - EntityModel model = cacheBuffer.getModel(entityModelId).get(); + CacheBufferType cacheBuffer = activeEnities.get(detectorId); + if (cacheBuffer != null && cacheBuffer.getModelState(entityModelId) != null) { + ModelState modelState = cacheBuffer.getModelState(entityModelId); Entity entity = null; - if (model != null && model.getEntity().isPresent()) { - entity = model.getEntity().get(); + if (modelState != null && modelState.getEntity().isPresent()) { + entity = modelState.getEntity().get(); } - return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerEntity())); + return Optional.of(new ModelProfile(entityModelId, entity, cacheBuffer.getMemoryConsumptionPerModel())); } return Optional.empty(); } @@ -907,11 +871,11 @@ public Optional getModelProfile(String detectorId, String entityMo * @param newDedicatedCacheSize the new dedicated cache size to validate */ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { - if (this.dedicatedCacheSize < newDedicatedCacheSize) { - int delta = newDedicatedCacheSize - this.dedicatedCacheSize; + if (this.hcDedicatedCacheSize < newDedicatedCacheSize) { + int delta = newDedicatedCacheSize - this.hcDedicatedCacheSize; long totalIncreasedBytes = 0; - for (CacheBuffer cacheBuffer : activeEnities.values()) { - totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerEntity() * delta; + for (CacheBufferType cacheBuffer : activeEnities.values()) { + totalIncreasedBytes += cacheBuffer.getMemoryConsumptionPerModel() * delta; } if (false == memoryTracker.canAllocateReserved(totalIncreasedBytes)) { @@ -922,13 +886,13 @@ private void validateDedicatedCacheSize(Integer newDedicatedCacheSize) { /** * Get a model 
state without incurring priority update. Used in maintenance. - * @param detectorId Detector Id + * @param configId Config Id * @param modelId Model Id * @return Model state */ @Override - public Optional> getForMaintainance(String detectorId, String modelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public Optional> getForMaintainance(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer == null) { return Optional.empty(); } @@ -936,31 +900,31 @@ public Optional> getForMaintainance(String detectorId, S } /** - * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model. - * @param detectorId Detector Id - * @param entityModelId Model Id + * Remove model from active entity buffer and delete checkpoint. Used to clean corrupted model. + * @param configId config Id + * @param modelId Model Id */ @Override - public void removeEntityModel(String detectorId, String entityModelId) { - CacheBuffer buffer = activeEnities.get(detectorId); + public void removeModel(String configId, String modelId) { + CacheBufferType buffer = activeEnities.get(configId); if (buffer != null) { - ModelState removed = null; - if ((removed = buffer.remove(entityModelId, false)) != null) { + ModelState removed = buffer.remove(modelId, false); + if (removed != null) { addIntoInactiveCache(removed); } } checkpointDao .deleteModelCheckpoint( - entityModelId, + modelId, ActionListener .wrap( - r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", entityModelId)), - e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", entityModelId), e) + r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)), + e -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), e) ) ); } - private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { + private Cache> createInactiveCache(Duration inactiveEntityTtl, int maxInactiveStates) { return CacheBuilder .newBuilder() .expireAfterAccess(inactiveEntityTtl.toHours(), TimeUnit.HOURS) @@ -968,4 +932,10 @@ private Cache> createInactiveCache(Duration inac .concurrencyLevel(1) .build(); } + + protected abstract Callable> createInactiveEntityCacheLoader(String modelId, String detectorId); + + protected abstract CacheBufferType createEmptyCacheBuffer(Config config, long memoryConsumptionPerEntity); + + protected abstract boolean isDoorKeeperInCacheEnabled(); } diff --git a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java similarity index 97% rename from src/main/java/org/opensearch/ad/caching/PriorityTracker.java rename to src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java index 439d67679..07f2087ec 100644 --- a/src/main/java/org/opensearch/ad/caching/PriorityTracker.java +++ b/src/main/java/org/opensearch/timeseries/caching/PriorityTracker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.caching; +package org.opensearch.timeseries.caching; import java.time.Clock; import java.util.AbstractMap.SimpleImmutableEntry; @@ -236,7 +236,7 @@ public void updatePriority(String entityId) { * @param entityId Entity Id * @param priority priority */ - protected void addPriority(String entityId, float priority) { + public void addPriority(String entityId, float priority) { PriorityNode node = new PriorityNode(entityId, priority); key2Priority.put(entityId, node); priorityList.add(node); @@ -260,7 +260,7 @@ private void adjustSizeIfRequired() { * Remove an entity from the tracker * @param entityId Entity Id */ - protected void removePriority(String entityId) { + public void removePriority(String entityId) { // remove if the key matches; priority does not matter priorityList.remove(new PriorityNode(entityId, 0)); key2Priority.remove(entityId); @@ -269,7 +269,7 @@ protected void removePriority(String entityId) { /** * Remove all entities */ - protected void clearPriority() { + public void clearPriority() { key2Priority.clear(); priorityList.clear(); } @@ -292,7 +292,7 @@ protected void clearPriority() { * * @return new priority */ - float getUpdatedPriority(float oldPriority) { + public float getUpdatedPriority(float oldPriority) { long increment = computeWeightedPriorityIncrement(); oldPriority += Math.log(1 + Math.exp(increment - oldPriority)); // if overflow happens, using the most recent decayed count instead. @@ -319,7 +319,7 @@ float getUpdatedPriority(float oldPriority) { * @param currentPriority Current priority * @return the scaled priority */ - float getScaledPriority(float currentPriority) { + public float getScaledPriority(float currentPriority) { return currentPriority - computeWeightedPriorityIncrement(); } diff --git a/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java new file mode 100644 index 000000000..fa5b0c1eb --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/caching/TimeSeriesCache.java @@ -0,0 +1,187 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.caching; + +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.timeseries.AnalysisModelSize; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public interface TimeSeriesCache extends MaintenanceState, CleanState, AnalysisModelSize { + /** + * + * @param config Analysis config + * @param toUpdate Model state candidate + * @return whether we can host the given model state + */ + boolean hostIfPossible(Config config, ModelState toUpdate); + + /** + * Get a model state without incurring a priority update or loading state from disk. Used in maintenance. + * @param configId Config Id + * @param modelId Model Id + * @return Model state + */ + Optional> getForMaintainance(String configId, String modelId); + + /** + * Get the ModelState associated with the modelId.
May or may not load the + * ModelState depending on the underlying cache's memory consumption. + * + * @param modelId Model Id + * @param config config accessor + * @return the ModelState associated with the modelId, or null if there is + * no cached item for it + */ + ModelState get(String modelId, Config config); + + /** + * Whether an entity is active or not + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity model Id + * @return Whether an entity is active or not + */ + boolean isActive(String configId, String entityModelId); + + /** + * Get total updates of the config's most active entity's RCF model. + * + * @param configId config id + * @return RCF model total updates of most active entity. + */ + long getTotalUpdates(String configId); + + /** + * Get RCF model total updates of a specific entity + * + * @param configId config id + * @param entityModelId entity model id + * @return RCF model total updates of the specific entity. + */ + long getTotalUpdates(String configId, String entityModelId); + + /** + * Gets modelStates of all models hosted on a node + * + * @return list of modelStates + */ + List> getAllModels(); + + /** + * Get the number of active entities of a config + * @param configId Config Id + * @return The number of active entities + */ + int getActiveEntities(String configId); + + /** + * + * @return total active entities in the cache + */ + int getTotalActiveEntities(); + + /** + * Return the last active time of an entity's state. + * + * If the entity's state is active in the cache, the value indicates when the cache + * was last accessed (get/put). If the entity's state is inactive in the cache, + * the value indicates when the cache state was created or when the entity was evicted + * from the active entity cache. + * + * @param configId The Id of the config that an entity belongs to + * @param entityModelId Entity's Model Id + * @return if the entity is in the cache, return the timestamp in epoch + * milliseconds when the entity's state was last used. Otherwise, return -1. + */ + long getLastActiveTime(String configId, String entityModelId); + + /** + * Release memory when memory circuit breaker is open + */ + void releaseMemoryForOpenCircuitBreaker(); + + /** + * Select candidate entities for which we can load models + * @param cacheMissEntities Cache miss entities + * @param configId Config Id + * @param config Config object + * @return A list of entities that are admitted into the cache as a result of the + * update and the left-over entities + */ + Pair, List> selectUpdateCandidate(Collection cacheMissEntities, String configId, Config config); + + /** + * + * @param configId Config Id + * @return a config's model information + */ + List getAllModelProfile(String configId); + + /** + * Gets an entity's model sizes + * + * @param configId Config Id + * @param entityModelId Entity's model Id + * @return the entity's memory size + */ + Optional getModelProfile(String configId, String entityModelId); + + /** + * Remove entity model from active entity buffer and delete checkpoint. Used to clean corrupted model.
+ * @param configId config Id + * @param entityModelId Model Id + */ + void removeModel(String configId, String entityModelId); + + /** + * + * @param config Detector config accessor + * @param memoryTracker memory tracker + * @param numberOfTrees number of trees + * @return Memory in bytes required for hosting one entity model + */ + default long getRequiredMemoryPerEntity(Config config, MemoryTracker memoryTracker, int numberOfTrees) { + int dimension = config.getEnabledFeatureIds().size() * config.getShingleSize(); + return memoryTracker + .estimateTRCFModelSize( + dimension, + numberOfTrees, + TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO, + config.getShingleSize().intValue(), + true + ); + } + + default long getTotalUpdates(ModelState modelState) { + // TODO: make it work for shingles. samples.size() is not the real shingle + long accumulatedShingles = Optional + .ofNullable(modelState) + .flatMap(model -> model.getModel()) + .map(trcf -> trcf.getForest()) + .map(rcf -> rcf.getTotalUpdates()) + .orElseGet( + () -> Optional + .ofNullable(modelState) + .map(model -> model.getSamples()) + .map(samples -> samples.size()) + .map(Long::valueOf) + .orElse(0L) + ); + return accumulatedShingles; + } +} diff --git a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java similarity index 90% rename from src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java rename to src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java index 62702d15c..2156d211c 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADDataMigrator.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ADDataMigrator.java @@ -9,15 +9,13 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD; -import static org.opensearch.ad.model.ADTaskType.taskTypeToString; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_DETECTOR_UPPER_LIMIT; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; @@ -39,12 +37,9 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorInternalState; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.xcontent.XContentFactory; @@ -60,6 +55,10 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import 
org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.util.ExceptionUtil; /** * Migrate AD data to support backward compatibility. @@ -137,13 +136,13 @@ public void migrateDetectorInternalStateToRealtimeTask() { logger.info("No anomaly detector job found, no need to migrate"); return; } - ConcurrentLinkedQueue detectorJobs = new ConcurrentLinkedQueue<>(); + ConcurrentLinkedQueue detectorJobs = new ConcurrentLinkedQueue<>(); Iterator iterator = r.getHits().iterator(); while (iterator.hasNext()) { SearchHit searchHit = iterator.next(); try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser); + Job job = Job.parse(parser); detectorJobs.add(job); } catch (IOException e) { logger.error("Fail to parse AD job " + searchHit.getId(), e); @@ -168,8 +167,8 @@ public void migrateDetectorInternalStateToRealtimeTask() { * @param detectorJobs realtime AD jobs * @param backfillAllJob backfill task for all realtime job or not */ - public void backfillRealtimeTask(ConcurrentLinkedQueue detectorJobs, boolean backfillAllJob) { - AnomalyDetectorJob job = detectorJobs.poll(); + public void backfillRealtimeTask(ConcurrentLinkedQueue detectorJobs, boolean backfillAllJob) { + Job job = detectorJobs.poll(); if (job == null) { logger.info("AD data migration done."); if (backfillAllJob) { @@ -203,19 +202,19 @@ public void backfillRealtimeTask(ConcurrentLinkedQueue detec } private void checkIfRealtimeTaskExistsAndBackfill( - AnomalyDetectorJob job, + Job job, ExecutorFunction createRealtimeTaskFunction, - ConcurrentLinkedQueue detectorJobs, + ConcurrentLinkedQueue detectorJobs, boolean migrateAll ) { String jobId = job.getName(); BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, jobId)); if (job.isEnabled()) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, true)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); } - query.filter(new TermsQueryBuilder(TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(ADTaskType.REALTIME_TASK_TYPES))); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(1); SearchRequest searchRequest = new SearchRequest(DETECTION_STATE_INDEX).source(searchSourceBuilder); client.search(searchRequest, ActionListener.wrap(r -> { @@ -233,20 +232,15 @@ private void checkIfRealtimeTaskExistsAndBackfill( })); } - private void createRealtimeADTask( - AnomalyDetectorJob job, - String error, - ConcurrentLinkedQueue detectorJobs, - boolean migrateAll - ) { + private void createRealtimeADTask(Job job, String error, ConcurrentLinkedQueue detectorJobs, boolean migrateAll) { client.get(new GetRequest(CommonName.CONFIG_INDEX, job.getName()), ActionListener.wrap(r -> { if (r != null && r.isExists()) { try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); AnomalyDetector detector = AnomalyDetector.parse(parser, r.getId()); ADTaskType taskType = detector.isHighCardinality() - ? ADTaskType.REALTIME_HC_DETECTOR - : ADTaskType.REALTIME_SINGLE_ENTITY; + ? 
ADTaskType.AD_REALTIME_HC_DETECTOR + : ADTaskType.AD_REALTIME_SINGLE_STREAM; Instant now = Instant.now(); String userName = job.getUser() != null ? job.getUser().getName() : null; ADTask adTask = new ADTask.Builder() @@ -258,7 +252,7 @@ private void createRealtimeADTask( .executionStartTime(now) .taskProgress(0.0f) .initProgress(0.0f) - .state(ADTaskState.CREATED.name()) + .state(TaskState.CREATED.name()) .lastUpdateTime(now) .startedBy(userName) .coordinatingNode(null) diff --git a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java similarity index 82% rename from src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java index fd00d9c22..7fc5a5716 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADClusterEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterEventListener.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.util.concurrent.Semaphore; @@ -23,8 +23,8 @@ import org.opensearch.common.inject.Inject; import org.opensearch.gateway.GatewayService; -public class ADClusterEventListener implements ClusterStateListener { - private static final Logger LOG = LogManager.getLogger(ADClusterEventListener.class); +public class ClusterEventListener implements ClusterStateListener { + private static final Logger LOG = LogManager.getLogger(ClusterEventListener.class); static final String NOT_RECOVERED_MSG = "Cluster is not recovered yet."; static final String IN_PROGRESS_MSG = "Cluster state change in progress, return."; static final String NODE_CHANGED_MSG = "Cluster node changed"; @@ -34,7 +34,7 @@ public class ADClusterEventListener implements ClusterStateListener { private final ClusterService clusterService; @Inject - public ADClusterEventListener(ClusterService clusterService, HashRing hashRing) { + public ClusterEventListener(ClusterService clusterService, HashRing hashRing) { this.clusterService = clusterService; this.clusterService.addListener(this); this.hashRing = hashRing; @@ -55,16 +55,13 @@ public void clusterChanged(ClusterChangedEvent event) { } try { - // Init AD version hash ring as early as possible. Some test case may fail as AD + // Init version hash ring as early as possible. Some test cases may fail if the // version hash ring is not initialized when tests run.
if (!hashRing.isHashRingInited()) { hashRing .buildCircles( ActionListener - .wrap( - r -> LOG.info("Init AD version hash ring successfully"), - e -> LOG.error("Failed to init AD version hash ring") - ) + .wrap(r -> LOG.info("Init version hash ring successfully"), e -> LOG.error("Failed to init version hash ring")) ); } Delta delta = event.nodesDelta(); @@ -80,7 +77,7 @@ public void clusterChanged(ClusterChangedEvent event) { ActionListener .wrap( hasRingBuildDone -> { LOG.info("Hash ring build result: {}", hasRingBuildDone); }, - e -> { LOG.error("Failed updating AD version hash ring", e); } + e -> { LOG.error("Failed updating version hash ring", e); } ), () -> inProgress.release() ) diff --git a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java similarity index 97% rename from src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java rename to src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java index 9cf1dd905..1ca9d5f1e 100644 --- a/src/main/java/org/opensearch/ad/cluster/ClusterManagerEventListener.java +++ b/src/main/java/org/opensearch/timeseries/cluster/ClusterManagerEventListener.java @@ -9,15 +9,13 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; import org.opensearch.ad.cluster.diskcleanup.IndexCleanup; import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.DateUtils; import org.opensearch.client.Client; import org.opensearch.cluster.LocalNodeClusterManagerListener; import org.opensearch.cluster.service.ClusterService; @@ -27,6 +25,8 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.util.ClientUtil; +import org.opensearch.timeseries.util.DateUtils; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.annotations.VisibleForTesting; diff --git a/src/main/java/org/opensearch/ad/cluster/DailyCron.java b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java similarity index 97% rename from src/main/java/org/opensearch/ad/cluster/DailyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/DailyCron.java index e2b2b8808..ca73b504f 100644 --- a/src/main/java/org/opensearch/ad/cluster/DailyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/DailyCron.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import java.time.Clock; import java.time.Duration; @@ -19,12 +19,12 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.IndicesOptions; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.ClientUtil; @Deprecated public class DailyCron implements Runnable { diff --git a/src/main/java/org/opensearch/ad/cluster/HashRing.java b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java similarity index 75% rename from src/main/java/org/opensearch/ad/cluster/HashRing.java rename to src/main/java/org/opensearch/timeseries/cluster/HashRing.java index 3e6ba0b37..e14d50054 100644 --- a/src/main/java/org/opensearch/ad/cluster/HashRing.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HashRing.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; import java.time.Clock; import java.util.ArrayList; @@ -36,8 +36,7 @@ import org.opensearch.action.admin.cluster.node.info.NodeInfo; import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest; import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.SingleStreamModelIdMapper; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -52,6 +51,7 @@ import org.opensearch.plugins.PluginInfo; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.Sets; @@ -69,11 +69,11 @@ public class HashRing { // Semaphore to control only 1 thread can build AD hash ring. private Semaphore buildHashRingSemaphore; - // This field is to track AD version of all nodes. - // Key: node id; Value: AD node info - private Map nodeAdVersions; - // This field records AD version hash ring in realtime way. Historical detection will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field is to track time series plugin version of all nodes. + // Key: node id; Value: node info + private Map nodeVersions; + // This field records time series version hash ring in realtime way. Historical detection will use this hash ring. + // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circles; // Track if hash ring inited or not. If not inited, the first clusterManager event will try to init it. private AtomicBoolean hashRingInited; @@ -82,8 +82,8 @@ public class HashRing { private long lastUpdateForRealtimeAD; // Cool down period before next hash ring rebuild. We need this as realtime AD needs stable hash ring. 
private volatile TimeValue coolDownPeriodForRealtimeAD; - // This field records AD version hash ring with cooldown period. Realtime job will use this hash ring. - // Key: AD version; Value: hash ring which only contains eligible data nodes + // This field records time series version hash ring with cooldown period. Realtime job will use this hash ring. + // Key: time series version; Value: hash ring which only contains eligible data nodes private TreeMap> circlesForRealtimeAD; // Record node change event. Will check if there is node change event when rebuild AD hash ring with @@ -95,7 +95,7 @@ public class HashRing { private final ADDataMigrator dataMigrator; private final Clock clock; private final Client client; - private final ModelManager modelManager; + private final ADModelManager modelManager; public HashRing( DiscoveryNodeFilterer nodeFilter, @@ -104,19 +104,19 @@ public HashRing( Client client, ClusterService clusterService, ADDataMigrator dataMigrator, - ModelManager modelManager + ADModelManager modelManager ) { this.nodeFilter = nodeFilter; this.buildHashRingSemaphore = new Semaphore(1); this.clock = clock; - this.coolDownPeriodForRealtimeAD = COOLDOWN_MINUTES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(COOLDOWN_MINUTES, it -> coolDownPeriodForRealtimeAD = it); + this.coolDownPeriodForRealtimeAD = AD_COOLDOWN_MINUTES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_COOLDOWN_MINUTES, it -> coolDownPeriodForRealtimeAD = it); this.lastUpdateForRealtimeAD = 0; this.client = client; this.clusterService = clusterService; this.dataMigrator = dataMigrator; - this.nodeAdVersions = new ConcurrentHashMap<>(); + this.nodeVersions = new ConcurrentHashMap<>(); this.circles = new TreeMap<>(); this.circlesForRealtimeAD = new TreeMap<>(); this.hashRingInited = new AtomicBoolean(false); @@ -129,17 +129,17 @@ public boolean isHashRingInited() { } /** - * Build AD version based circles with discovery node delta change. Listen to clusterManager event in - * {@link ADClusterEventListener#clusterChanged(ClusterChangedEvent)}. + * Build version based circles with discovery node delta change. Listen to clusterManager event in + * {@link ClusterEventListener#clusterChanged(ClusterChangedEvent)}. * Will remove the removed nodes from cache and send request to newly added nodes to get their - * plugin information; then add new nodes to AD version hash ring. + * plugin information; then add new nodes to version hash ring. * * @param delta discovery node delta change * @param listener action listener */ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener listener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't build hash ring for node delta event."); + LOG.info("hash ring change is in progress. Can't build hash ring for node delta event."); listener.onResponse(false); return; } @@ -151,14 +151,14 @@ public void buildCircles(DiscoveryNodes.Delta delta, ActionListener lis } /** - * Build AD version based circles by comparing with all eligible data nodes. + * Build version based circles by comparing with all eligible data nodes. * 1. Remove nodes which are not eligible now; - * 2. Add nodes which are not in AD version circles. + * 2. Add nodes which are not in version circles. 
* @param actionListener action listener */ public void buildCircles(ActionListener actionListener) { if (!buildHashRingSemaphore.tryAcquire()) { - LOG.info("AD version hash ring change is in progress. Can't rebuild hash ring."); + LOG.info("hash ring change is in progress. Can't rebuild hash ring."); actionListener.onResponse(false); return; } @@ -167,39 +167,35 @@ public void buildCircles(ActionListener actionListener) { for (DiscoveryNode node : allNodes) { nodeIds.add(node.getId()); } - Set currentNodeIds = nodeAdVersions.keySet(); + Set currentNodeIds = nodeVersions.keySet(); Set removedNodeIds = Sets.difference(currentNodeIds, nodeIds); Set addedNodeIds = Sets.difference(nodeIds, currentNodeIds); buildCircles(removedNodeIds, addedNodeIds, actionListener); } - public void buildCirclesForRealtimeAD() { + public void buildCirclesForRealtime() { if (nodeChangeEvents.isEmpty()) { return; } buildCircles( - ActionListener - .wrap( - r -> { LOG.debug("build circles on AD versions successfully"); }, - e -> { LOG.error("Failed to build circles on AD versions", e); } - ) + ActionListener.wrap(r -> { LOG.debug("build circles successfully"); }, e -> { LOG.error("Failed to build circles", e); }) ); } /** - * Build AD version hash ring. - * 1. Delete removed nodes from AD version hash ring. - * 2. Add new nodes to AD version hash ring + * Build version hash ring. + * 1. Delete removed nodes from version hash ring. + * 2. Add new nodes to version hash ring * - * If fail to acquire semaphore to update AD version hash ring, will return false to + * If fail to acquire semaphore to update version hash ring, will return false to * action listener; otherwise will return true. The "true" response just mean we got * semaphore and finished rebuilding hash ring, but the hash ring may stay the same. * Hash ring changed or not depends on if "removedNodeIds" or "addedNodeIds" is empty. * * We use different way to build hash ring for realtime job and historical analysis - * 1. For historical analysis,if node removed, we remove it immediately from adVersionCircles - * to avoid new AD task routes to it. If new node added, we add it immediately to adVersionCircles - * to make load more balanced and speed up AD task running. + * 1. For historical analysis, if a node is removed, we remove it immediately from version circles + * to avoid routing new tasks to it. If a new node is added, we add it immediately to version circles + * to make load more balanced and speed up task running. * 2. For realtime job, we don't record which node running detector's model partition. We just * use hash ring to get owning node. If we rebuild hash ring frequently, realtime job may get * different owning node and need to restore model on new owning node. If that happens a lot, @@ -209,7 +205,7 @@ public void buildCirclesForRealtimeAD() { * and still send RCF request to it. If new node added during cooldown period, realtime job won't * choose it as model partition owning node, thus we may have skewed load on data nodes. * - * [Important!]: When you call this function, make sure you TRY ACQUIRE adVersionCircleInProgress first. + * [Important!]: When you call this function, make sure you TRY ACQUIRE buildHashRingSemaphore first.
* Check {@link HashRing#buildCircles(ActionListener)} and * {@link HashRing#buildCircles(DiscoveryNodes.Delta, ActionListener)} * @@ -226,10 +222,10 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (removedNodeIds != null && removedNodeIds.size() > 0) { LOG.info("Node removed: {}", Arrays.toString(removedNodeIds.toArray(new String[0]))); for (String nodeId : removedNodeIds) { - ADNodeInfo nodeInfo = nodeAdVersions.remove(nodeId); + TimeSeriesNodeInfo nodeInfo = nodeVersions.remove(nodeId); if (nodeInfo != null && nodeInfo.isEligibleDataNode()) { - removeNodeFromCircles(nodeId, nodeInfo.getAdVersion()); - LOG.info("Remove data node from AD version hash ring: {}", nodeId); + removeNodeFromCircles(nodeId, nodeInfo.getVersion()); + LOG.info("Remove data node from version hash ring: {}", nodeId); } } } @@ -238,12 +234,12 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, if (addedNodeIds != null) { allAddedNodes.addAll(addedNodeIds); } - if (!nodeAdVersions.containsKey(localNode.getId())) { + if (!nodeVersions.containsKey(localNode.getId())) { allAddedNodes.add(localNode.getId()); } if (allAddedNodes.size() == 0) { actionListener.onResponse(true); - // rebuild AD version hash ring with cooldown. + // rebuild version hash ring with cooldown. rebuildCirclesForRealtimeAD(); buildHashRingSemaphore.release(); return; @@ -268,15 +264,16 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } TreeMap circle = null; for (PluginInfo pluginInfo : plugins.getPluginInfos()) { + // if (AD_PLUGIN_NAME.equals(pluginInfo.getName()) || AD_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { if (CommonName.TIME_SERIES_PLUGIN_NAME.equals(pluginInfo.getName()) || CommonName.TIME_SERIES_PLUGIN_NAME_FOR_TEST.equals(pluginInfo.getName())) { - Version version = ADVersionUtil.fromString(pluginInfo.getVersion()); + Version version = VersionUtil.fromString(pluginInfo.getVersion()); boolean eligibleNode = nodeFilter.isEligibleNode(curNode); if (eligibleNode) { circle = circles.computeIfAbsent(version, key -> new TreeMap<>()); - LOG.info("Add data node to AD version hash ring: {}", curNode.getId()); + LOG.info("Add data node to version hash ring: {}", curNode.getId()); } - nodeAdVersions.put(curNode.getId(), new ADNodeInfo(version, eligibleNode)); + nodeVersions.put(curNode.getId(), new TimeSeriesNodeInfo(version, eligibleNode)); break; } } @@ -287,15 +284,15 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, } } } - LOG.info("All nodes with known AD version: {}", nodeAdVersions); + LOG.info("All nodes with known version: {}", nodeVersions); - // rebuild AD version hash ring with cooldown after all new node added. + // rebuild version hash ring with cooldown after all new nodes are added. rebuildCirclesForRealtimeAD(); if (!dataMigrator.isMigrated() && circles.size() > 0) { - // Find owning node with highest AD version to make sure the data migration logic be compatible to - // latest AD version when upgrade. + // Find owning node with highest version to make sure the data migration logic is compatible with the + // latest version when upgrading.
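// Editorial sketch (illustrative, not part of this diff): how the TreeMap-based hash ring resolves an owning node, as getOwningNodeWithHighestVersion below does. A model id hashes to a point on the ring; the owner is the first node clockwise from that point, wrapping around to the first entry. Only the TreeMap layout comes from this class; the helper name, signature, and imports (java.util.TreeMap, java.util.Map, java.util.Optional, DiscoveryNode) are assumptions. static Optional<DiscoveryNode> ownerOf(TreeMap<Integer, DiscoveryNode> ring, int modelHash) { // first virtual-node position strictly greater than the model's hash Map.Entry<Integer, DiscoveryNode> entry = ring.higherEntry(modelHash); // wrap around to the ring's first entry when the hash falls past the last position return Optional.ofNullable(entry != null ? entry : ring.firstEntry()).map(Map.Entry::getValue); }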
+ Optional owningNode = getOwningNodeWithHighestVersion(DEFAULT_HASH_RING_MODEL_ID); String localNodeId = localNode.getId(); if (owningNode.isPresent() && localNodeId.equals(owningNode.get().getId())) { dataMigrator.migrateData(); @@ -309,18 +306,18 @@ private void buildCircles(Set removedNodeIds, Set addedNodeIds, }, e -> { buildHashRingSemaphore.release(); actionListener.onFailure(e); - LOG.error("Fail to get node info to build AD version hash ring", e); + LOG.error("Fail to get node info to build hash ring", e); })); } catch (Exception e) { - LOG.error("Failed to build AD version circles", e); + LOG.error("Failed to build circles", e); buildHashRingSemaphore.release(); actionListener.onFailure(e); } } - private void removeNodeFromCircles(String nodeId, Version adVersion) { - if (adVersion != null) { - TreeMap circle = this.circles.get(adVersion); + private void removeNodeFromCircles(String nodeId, Version version) { + if (version != null) { + TreeMap circle = this.circles.get(version); List deleted = new ArrayList<>(); for (Map.Entry entry : circle.entrySet()) { if (entry.getValue().getId().equals(nodeId)) { @@ -328,7 +325,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { } } if (deleted.size() == circle.size()) { - circles.remove(adVersion); + circles.remove(version); } else { for (Integer key : deleted) { circle.remove(key); @@ -340,7 +337,7 @@ private void removeNodeFromCircles(String nodeId, Version adVersion) { private void rebuildCirclesForRealtimeAD() { // Check if it's eligible to rebuild hash ring with cooldown if (eligibleToRebuildCirclesForRealtimeAD()) { - LOG.info("Rebuild AD hash ring for realtime AD with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); + LOG.info("Rebuild hash ring for realtime with cooldown, nodeChangeEvents size {}", nodeChangeEvents.size()); int size = nodeChangeEvents.size(); TreeMap> newCircles = new TreeMap<>(); for (Map.Entry> entry : circles.entrySet()) { @@ -348,17 +345,17 @@ private void rebuildCirclesForRealtimeAD() { } circlesForRealtimeAD = newCircles; lastUpdateForRealtimeAD = clock.millis(); - LOG.info("Build AD version hash ring successfully"); + LOG.info("Build version hash ring successfully"); String localNodeId = clusterService.localNode().getId(); Set modelIds = modelManager.getAllModelIds(); for (String modelId : modelIds) { - Optional node = getOwningNodeWithSameLocalAdVersionForRealtimeAD(modelId); + Optional node = getOwningNodeWithSameLocalVersionForRealtime(modelId); if (node.isPresent() && !node.get().getId().equals(localNodeId)) { LOG.info(REMOVE_MODEL_MSG + " {}", modelId); modelManager .stopModel( // stopModel will clear model cache - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId), + SingleStreamModelIdMapper.getConfigIdForModelId(modelId), modelId, ActionListener .wrap( @@ -370,7 +367,7 @@ private void rebuildCirclesForRealtimeAD() { } // It's possible that multiple threads add new event to nodeChangeEvents, // but this is the only place to consume/poll the event and there is only - // one thread poll it as we are using adVersionCircleInProgress semaphore(1) + // one thread poll it as we are using buildHashRingSemaphore // to control only 1 thread build hash ring. while (size-- > 0) { Boolean poll = nodeChangeEvents.poll(); @@ -387,7 +384,7 @@ private void rebuildCirclesForRealtimeAD() { * 1. There is node change event not consumed, and * 2. Have passed cool down period from last hash ring update time. 
* - * Check {@link org.opensearch.ad.settings.AnomalyDetectorSettings#COOLDOWN_MINUTES} about + * Check {@link org.opensearch.ad.settings.AnomalyDetectorSettings#AD_COOLDOWN_MINUTES} about * cool down settings. * * Why we need to wait for some cooldown period before rebuilding hash ring? @@ -416,71 +413,71 @@ protected boolean eligibleToRebuildCirclesForRealtimeAD() { } /** - * Get owning node with highest AD version circle. + * Get owning node with highest version circle. * @param modelId model id * @return owning node */ - public Optional getOwningNodeWithHighestAdVersion(String modelId) { + public Optional getOwningNodeWithHighestVersion(String modelId) { int modelHash = Murmur3HashFunction.hash(modelId); Map.Entry> versionTreeMapEntry = circles.lastEntry(); if (versionTreeMapEntry == null) { return Optional.empty(); } - TreeMap adVersionCircle = versionTreeMapEntry.getValue(); - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = versionTreeMapEntry.getValue(); + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } /** - * Get owning node with same AD version of local node. + * Get owning node with same version of local node. * @param modelId model id * @param function consumer function * @param listener action listener * @param listener response type */ - public void buildAndGetOwningNodeWithSameLocalAdVersion( + public void buildAndGetOwningNodeWithSameLocalVersion( String modelId, Consumer> function, ActionListener listener ) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, false); function.accept(owningNode); }, e -> listener.onFailure(e))); } - public Optional getOwningNodeWithSameLocalAdVersionForRealtimeAD(String modelId) { + public Optional getOwningNodeWithSameLocalVersionForRealtime(String modelId) { try { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Optional owningNode = getOwningNodeWithSameAdVersionDirectly(modelId, adVersion, true); + Version version = nodeVersions.containsKey(localNode.getId()) ? 
getVersion(localNode.getId()) : Version.CURRENT; + Optional owningNode = getOwningNodeWithSameVersionDirectly(modelId, version, true); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return owningNode; } catch (Exception e) { - LOG.error("Failed to get owning node with same local AD version", e); + LOG.error("Failed to get owning node with same local time series version", e); return Optional.empty(); } } - private Optional getOwningNodeWithSameAdVersionDirectly(String modelId, Version adVersion, boolean forRealtime) { + private Optional getOwningNodeWithSameVersionDirectly(String modelId, Version version, boolean forRealtime) { int modelHash = Murmur3HashFunction.hash(modelId); - TreeMap adVersionCircle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); - if (adVersionCircle != null) { - Map.Entry entry = adVersionCircle.higherEntry(modelHash); - return Optional.ofNullable(Optional.ofNullable(entry).orElse(adVersionCircle.firstEntry())).map(x -> x.getValue()); + TreeMap versionCircle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); + if (versionCircle != null) { + Map.Entry entry = versionCircle.higherEntry(modelHash); + return Optional.ofNullable(Optional.ofNullable(entry).orElse(versionCircle.firstEntry())).map(x -> x.getValue()); } return Optional.empty(); } - public void getNodesWithSameLocalAdVersion(Consumer function, ActionListener listener) { + public void getNodesWithSameLocalVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(updated -> { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); - if (!nodeAdVersions.containsKey(localNode.getId())) { + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); + if (!nodeVersions.containsKey(localNode.getId())) { nodes.add(localNode); } // Make sure listener return in function @@ -488,17 +485,17 @@ public void getNodesWithSameLocalAdVersion(Consumer functio }, e -> listener.onFailure(e))); } - public DiscoveryNode[] getNodesWithSameLocalAdVersion() { + public DiscoveryNode[] getNodesWithSameLocalVersion() { DiscoveryNode localNode = clusterService.localNode(); - Version adVersion = nodeAdVersions.containsKey(localNode.getId()) ? getAdVersion(localNode.getId()) : Version.CURRENT; - Set nodes = getNodesWithSameAdVersion(adVersion, false); + Version version = nodeVersions.containsKey(localNode.getId()) ? getVersion(localNode.getId()) : Version.CURRENT; + Set nodes = getNodesWithSameVersion(version, false); // rebuild hash ring - buildCirclesForRealtimeAD(); + buildCirclesForRealtime(); return nodes.toArray(new DiscoveryNode[0]); } - protected Set getNodesWithSameAdVersion(Version adVersion, boolean forRealtime) { - TreeMap circle = forRealtime ? circlesForRealtimeAD.get(adVersion) : circles.get(adVersion); + protected Set getNodesWithSameVersion(Version version, boolean forRealtime) { + TreeMap circle = forRealtime ? circlesForRealtimeAD.get(version) : circles.get(version); Set nodeIds = new HashSet<>(); Set nodes = new HashSet<>(); if (circle == null) { @@ -515,13 +512,13 @@ protected Set getNodesWithSameAdVersion(Version adVersion, boolea } /** - * Get AD version. + * Get time series version. 
* @param nodeId node id - * @return AD version + * @return version */ - public Version getAdVersion(String nodeId) { - ADNodeInfo adNodeInfo = nodeAdVersions.get(nodeId); - return adNodeInfo == null ? null : adNodeInfo.getAdVersion(); + public Version getVersion(String nodeId) { + TimeSeriesNodeInfo nodeInfo = nodeVersions.get(nodeId); + return nodeInfo == null ? null : nodeInfo.getVersion(); } /** @@ -565,17 +562,17 @@ private String getIpAddress(TransportAddress address) { } /** - * Get all eligible data nodes whose AD versions are known in AD version based hash ring. + * Get all eligible data nodes whose time series versions are known in the version-based hash ring. * @param function consumer function * @param listener action listener * @param action listener response type */ - public void getAllEligibleDataNodesWithKnownAdVersion(Consumer function, ActionListener listener) { + public void getAllEligibleDataNodesWithKnownVersion(Consumer function, ActionListener listener) { buildCircles(ActionListener.wrap(r -> { DiscoveryNode[] eligibleDataNodes = nodeFilter.getEligibleDataNodes(); List allNodes = new ArrayList<>(); for (DiscoveryNode node : eligibleDataNodes) { - if (nodeAdVersions.containsKey(node.getId())) { + if (nodeVersions.containsKey(node.getId())) { allNodes.add(node); } } diff --git a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java similarity index 98% rename from src/main/java/org/opensearch/ad/cluster/HourlyCron.java rename to src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java index a81156bb0..7c566de70 100644 --- a/src/main/java/org/opensearch/ad/cluster/HourlyCron.java +++ b/src/main/java/org/opensearch/timeseries/cluster/HourlyCron.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; diff --git a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java similarity index 70% rename from src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java rename to src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java index e438623d5..f67d663ae 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADNodeInfo.java +++ b/src/main/java/org/opensearch/timeseries/cluster/TimeSeriesNodeInfo.java @@ -9,25 +9,25 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; /** - * This class records AD version of nodes and whether node is eligible data node to run AD. + * This class records the time series plugin version of each node and whether the node is an eligible data node for time series analysis. */ -public class ADNodeInfo { - // AD plugin version +public class TimeSeriesNodeInfo { + // time series plugin version private Version adVersion; // Is node eligible to run time series analysis.
private boolean isEligibleDataNode; - public ADNodeInfo(Version version, boolean isEligibleDataNode) { + public TimeSeriesNodeInfo(Version version, boolean isEligibleDataNode) { this.adVersion = version; this.isEligibleDataNode = isEligibleDataNode; } - public Version getAdVersion() { + public Version getVersion() { return adVersion; } diff --git a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java similarity index 95% rename from src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java rename to src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java index 7e880de66..8d506732d 100644 --- a/src/main/java/org/opensearch/ad/cluster/ADVersionUtil.java +++ b/src/main/java/org/opensearch/timeseries/cluster/VersionUtil.java @@ -9,12 +9,12 @@ * GitHub history for details. */ -package org.opensearch.ad.cluster; +package org.opensearch.timeseries.cluster; import org.opensearch.Version; import org.opensearch.timeseries.constant.CommonName; -public class ADVersionUtil { +public class VersionUtil { public static final int VERSION_SEGMENTS = 3; diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java index 393248237..42a4ae717 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonMessages.java @@ -85,4 +85,19 @@ public static String getTooManyCategoricalFieldErr(int limit) { public static final String BUG_RESPONSE = "We might have bugs."; public static final String MEMORY_LIMIT_EXCEEDED_ERR_MSG = "Models memory usage exceeds our limit."; + // ====================================== + // transport + // ====================================== + public static final String CONFIG_ID_MISSING_MSG = "config ID is missing"; + public static final String MODEL_ID_MISSING_MSG = "model ID is missing"; + + // ====================================== + // task + // ====================================== + public static String CAN_NOT_FIND_LATEST_TASK = "can't find latest task"; + + // ====================================== + // Job + // ====================================== + public static String CONFIG_IS_RUNNING = "Config is already running"; } diff --git a/src/main/java/org/opensearch/timeseries/constant/CommonName.java b/src/main/java/org/opensearch/timeseries/constant/CommonName.java index 0b997ea5d..060204298 100644 --- a/src/main/java/org/opensearch/timeseries/constant/CommonName.java +++ b/src/main/java/org/opensearch/timeseries/constant/CommonName.java @@ -113,4 +113,9 @@ public class CommonName { public static final String TIME_SERIES_PLUGIN_NAME = "opensearch-time-series-analytics"; public static final String TIME_SERIES_PLUGIN_NAME_FOR_TEST = "org.opensearch.timeseries.TimeSeriesAnalyticsPlugin"; public static final String TIME_SERIES_PLUGIN_VERSION_FOR_TEST = "NA"; + + // ====================================== + // Profile name + // ====================================== + public static final String CATEGORICAL_FIELD = "category_field"; } diff --git a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java similarity index 99% rename from src/main/java/org/opensearch/ad/feature/AbstractRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java index 886dbcbc4..5f2609ed5 100644 --- 
a/src/main/java/org/opensearch/ad/feature/AbstractRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/AbstractRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Arrays; import java.util.Iterator; diff --git a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java similarity index 93% rename from src/main/java/org/opensearch/ad/feature/CompositeRetriever.java rename to src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java index d41bdf76e..2674d8516 100644 --- a/src/main/java/org/opensearch/ad/feature/CompositeRetriever.java +++ b/src/main/java/org/opensearch/timeseries/feature/CompositeRetriever.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.io.IOException; import java.time.Clock; @@ -28,8 +28,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.service.ClusterService; @@ -46,9 +44,12 @@ import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; /** * @@ -65,7 +66,7 @@ public class CompositeRetriever extends AbstractRetriever { private final long dataStartEpoch; private final long dataEndEpoch; - private final AnomalyDetector anomalyDetector; + private final Config config; private final NamedXContentRegistry xContent; private final Client client; private final SecurityClientUtil clientUtil; @@ -77,11 +78,12 @@ public class CompositeRetriever extends AbstractRetriever { private Clock clock; private IndexNameExpressionResolver indexNameExpressionResolver; private ClusterService clusterService; + private AnalysisType context; public CompositeRetriever( long dataStartEpoch, long dataEndEpoch, - AnomalyDetector anomalyDetector, + Config config, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -91,11 +93,12 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this.dataStartEpoch = dataStartEpoch; this.dataEndEpoch = dataEndEpoch; - this.anomalyDetector = anomalyDetector; + this.config = config; this.xContent = xContent; this.client = client; this.clientUtil = clientUtil; @@ -106,13 +109,14 @@ public CompositeRetriever( this.clock = clock; this.indexNameExpressionResolver = indexNameExpressionResolver; this.clusterService = clusterService; + this.context = context; } // a constructor that provide default value of clock public CompositeRetriever( long dataStartEpoch, long 
dataEndEpoch, - AnomalyDetector anomalyDetector, + Config anomalyDetector, NamedXContentRegistry xContent, Client client, SecurityClientUtil clientUtil, @@ -121,7 +125,8 @@ public CompositeRetriever( int maxEntitiesPerInterval, int pageSize, IndexNameExpressionResolver indexNameExpressionResolver, - ClusterService clusterService + ClusterService clusterService, + AnalysisType context ) { this( dataStartEpoch, @@ -136,7 +141,8 @@ public CompositeRetriever( maxEntitiesPerInterval, pageSize, indexNameExpressionResolver, - clusterService + clusterService, + context ); } @@ -146,21 +152,21 @@ public CompositeRetriever( * detector definition */ public PageIterator iterator() throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(anomalyDetector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .gte(dataStartEpoch) .lt(dataEndEpoch) .format("epoch_millis"); - BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(anomalyDetector.getFilterQuery()).filter(rangeQuery); + BoolQueryBuilder internalFilterQuery = new BoolQueryBuilder().filter(config.getFilterQuery()).filter(rangeQuery); // multiple categorical fields are supported CompositeAggregationBuilder composite = AggregationBuilders .composite( AGG_NAME_COMP, - anomalyDetector.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) + config.getCategoryFields().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList()) ) .size(pageSize); - for (Feature feature : anomalyDetector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { AggregatorFactories.Builder internalAgg = ParseUtils .parseAggregators(feature.getAggregation().toString(), xContent, feature.getId()); composite.subAggregation(internalAgg.getAggregatorFactories().iterator().next()); @@ -200,7 +206,7 @@ public void next(ActionListener listener) { // inject user role while searching. 
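// Editorial sketch (illustrative, not part of this diff): how composite-aggregation paging works in general. Each response carries an "after key"; feeding it back via aggregateAfter resumes from the last returned bucket. The names used here ("comp_agg", "host", afterKey, filterQuery, pageSize) are assumptions for illustration only; the retriever's own builders appear in iterator() above. SearchSourceBuilder pageSource(QueryBuilder filterQuery, Map<String, Object> afterKey, int pageSize) { CompositeAggregationBuilder composite = AggregationBuilders .composite("comp_agg", List.of(new TermsValuesSourceBuilder("host").field("host"))) .size(pageSize); // at most pageSize entity buckets per page if (afterKey != null) { composite.aggregateAfter(afterKey); // resume after the previous page's last bucket } // size(0): only the aggregation buckets are needed, not raw hits return new SearchSourceBuilder().query(filterQuery).aggregation(composite).size(0); }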
- SearchRequest searchRequest = new SearchRequest(anomalyDetector.getIndices().toArray(new String[0]), source); + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0]), source); final ActionListener searchResponseListener = new ActionListener() { @Override public void onResponse(SearchResponse response) { @@ -218,8 +224,9 @@ public void onFailure(Exception e) { .asyncRequestWithInjectedSecurity( searchRequest, client::search, - anomalyDetector.getId(), + config.getId(), client, + context, searchResponseListener ); } @@ -289,7 +296,7 @@ private Page analyzePage(SearchResponse response) { } */ for (Bucket bucket : composite.getBuckets()) { - Optional featureValues = parseBucket(bucket, anomalyDetector.getEnabledFeatureIds()); + Optional featureValues = parseBucket(bucket, config.getEnabledFeatureIds()); // bucket.getKey() returns a map of categorical field like "host" and its value like "server_1" if (featureValues.isPresent() && bucket.getKey() != null) { results.put(Entity.createEntityByReordering(bucket.getKey()), featureValues.get()); @@ -333,7 +340,7 @@ Optional getComposite(SearchResponse response) { // such index // [blah]","index":"blah","resource.id":"blah","resource.type":"index_or_alias","index_uuid":"_na_"},"status":404}% if (response == null || response.getAggregations() == null) { - List sourceIndices = anomalyDetector.getIndices(); + List sourceIndices = config.getIndices(); String[] concreteIndices = indexNameExpressionResolver .concreteIndexNames(clusterService.state(), IndicesOptions.lenientExpandOpen(), sourceIndices.toArray(new String[0])); if (concreteIndices.length == 0) { diff --git a/src/main/java/org/opensearch/ad/feature/FeatureManager.java b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java similarity index 74% rename from src/main/java/org/opensearch/ad/feature/FeatureManager.java rename to src/main/java/org/opensearch/timeseries/feature/FeatureManager.java index f6fd8ded0..b32835267 100644 --- a/src/main/java/org/opensearch/ad/feature/FeatureManager.java +++ b/src/main/java/org/opensearch/timeseries/feature/FeatureManager.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import static java.util.Arrays.copyOfRange; import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; @@ -23,30 +23,42 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.Deque; +import java.util.Iterator; import java.util.LinkedList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.Queue; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import java.util.stream.DoubleStream; import java.util.stream.IntStream; import java.util.stream.LongStream; import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.CleanState; import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.forecast.model.Forecaster; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.CleanState; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.util.DataUtil; /** * A facade managing feature data operations and buffers. @@ -55,15 +67,13 @@ public class FeatureManager implements CleanState { private static final Logger logger = LogManager.getLogger(FeatureManager.class); - // Each anomaly detector has a queue of data points with timestamps (in epoch milliseconds). + // Each single-stream analysis has a queue of data points with timestamps (in epoch milliseconds). 
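A rendering note for reviewers: angle-bracketed type parameters were dropped when this diff was rendered, so several declarations below look truncated. For instance, the field on the next line is presumably declared with its full generics as:
// Presumed full declaration; the generic arguments were lost in the rendered diff.
private final Map<String, Deque<Entry<Long, Optional<double[]>>>> detectorIdsToTimeShingles;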
private final Map>>> detectorIdsToTimeShingles; private final SearchFeatureDao searchFeatureDao; private final Imputer imputer; private final Clock clock; - private final int maxTrainSamples; - private final int maxSampleStride; private final int trainSampleTimeRangeInHours; private final int minTrainSamples; private final double maxMissingPointsRate; @@ -96,8 +106,6 @@ public FeatureManager( SearchFeatureDao searchFeatureDao, Imputer imputer, Clock clock, - int maxTrainSamples, - int maxSampleStride, int trainSampleTimeRangeInHours, int minTrainSamples, double maxMissingPointsRate, @@ -111,8 +119,6 @@ public FeatureManager( this.searchFeatureDao = searchFeatureDao; this.imputer = imputer; this.clock = clock; - this.maxTrainSamples = maxTrainSamples; - this.maxSampleStride = maxSampleStride; this.trainSampleTimeRangeInHours = trainSampleTimeRangeInHours; this.minTrainSamples = minTrainSamples; this.maxMissingPointsRate = maxMissingPointsRate; @@ -156,7 +162,7 @@ public void getCurrentFeatures(AnomalyDetector detector, long startTime, long en if (missingRanges.size() > 0) { try { - searchFeatureDao.getFeatureSamplesForPeriods(detector, missingRanges, ActionListener.wrap(points -> { + searchFeatureDao.getFeatureSamplesForPeriods(detector, missingRanges, AnalysisType.AD, ActionListener.wrap(points -> { for (int i = 0; i < points.size(); i++) { Optional point = points.get(i); long rangeEndTime = missingRanges.get(i).getValue(); @@ -172,8 +178,45 @@ public void getCurrentFeatures(AnomalyDetector detector, long startTime, long en } } + public void getCurrentFeatures(Forecaster forecaster, long startTime, long endTime, ActionListener listener) { + List> missingRanges = Collections.singletonList(new SimpleImmutableEntry<>(startTime, endTime)); + try { + searchFeatureDao + .getFeatureSamplesForPeriods( + forecaster, + missingRanges, + AnalysisType.FORECAST, + ActionListener + .wrap( + points -> { + // we only have one point + if (points.size() == 1) { + Optional point = points.get(0); + listener.onResponse(new SinglePointFeatures(point, Optional.empty())); + } else { + listener.onResponse(new SinglePointFeatures(Optional.empty(), Optional.empty())); + } + }, + listener::onFailure + ) + ); + } catch (IOException e) { + listener.onFailure(new EndRunException(forecaster.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, true)); + } + } + + public void getCurrentFeatures(Config config, long startTime, long endTime, ActionListener listener) { + if (config instanceof AnomalyDetector) { + getCurrentFeatures((AnomalyDetector) config, startTime, endTime, listener); + } else if (config instanceof Forecaster) { + getCurrentFeatures((Forecaster) config, startTime, endTime, listener); + } else { + throw new UnsupportedOperationException(String.format(Locale.ROOT, "config type %s is not supported.", config.getClass())); + } + } + private List> getMissingRangesInShingle( - AnomalyDetector detector, + Config detector, Map>> featuresMap, long endTime ) { @@ -205,7 +248,7 @@ private List> getMissingRangesInShingle( * @param listener onResponse is called with unprocessed features and processed features for the current data point. 
*/ private void updateUnprocessedFeatures( - AnomalyDetector detector, + Config detector, Deque>> shingle, Map>> featuresMap, long endTime, @@ -219,17 +262,19 @@ private void updateUnprocessedFeatures( listener.onResponse(getProcessedFeatures(shingle, detector, endTime)); } - private double[][] filterAndFill(Deque>> shingle, long endTime, AnomalyDetector detector) { - int shingleSize = detector.getShingleSize(); + private double[][] filterAndFill(Deque>> shingle, long endTime, Config config) { + double[][] result = null; + + int shingleSize = config.getShingleSize(); Deque>> filteredShingle = shingle .stream() .filter(e -> e.getValue().isPresent()) .collect(Collectors.toCollection(ArrayDeque::new)); - double[][] result = null; + if (filteredShingle.size() >= shingleSize - getMaxMissingPoints(shingleSize)) { // Imputes missing data points with the values of neighboring data points. - long maxMillisecondsDifference = maxNeighborDistance * detector.getIntervalInMilliseconds(); - result = getNearbyPointsForShingle(detector, filteredShingle, endTime, maxMillisecondsDifference) + long maxMillisecondsDifference = maxNeighborDistance * config.getIntervalInMilliseconds(); + result = getNearbyPointsForShingle(config, filteredShingle, endTime, maxMillisecondsDifference) .map(e -> e.getValue().getValue().orElse(null)) .filter(d -> d != null) .toArray(double[][]::new); @@ -238,6 +283,7 @@ private double[][] filterAndFill(Deque>> shingle, result = null; } } + return result; } @@ -252,7 +298,7 @@ private double[][] filterAndFill(Deque>> shingle, * point value. */ private Stream>>> getNearbyPointsForShingle( - AnomalyDetector detector, + Config detector, Deque>> shingle, long endTime, long maxMillisecondsDifference @@ -306,6 +352,7 @@ private void getColdStartSamples(Optional latest, AnomalyDetector detector .getFeatureSamplesForPeriods( detector, sampleRanges, + AnalysisType.AD, new ThreadedActionListener<>(logger, threadPool, adThreadPoolName, getFeaturesListener, false) ); } catch (IOException e) { @@ -545,7 +592,14 @@ void getPreviewSamplesInRangesForEntity( ActionListener>, double[][]>> listener ) throws IOException { searchFeatureDao - .getColdStartSamplesForPeriods(detector, sampleRanges, entity, true, getSamplesRangesListener(sampleRanges, listener)); + .getColdStartSamplesForPeriods( + detector, + sampleRanges, + Optional.ofNullable(entity), + true, + AnalysisType.AD, + getSamplesRangesListener(sampleRanges, listener) + ); } private ActionListener>> getSamplesRangesListener( @@ -577,7 +631,8 @@ void getSamplesForRanges( List> sampleRanges, ActionListener>, double[][]>> listener ) throws IOException { - searchFeatureDao.getFeatureSamplesForPeriods(detector, sampleRanges, getSamplesRangesListener(sampleRanges, listener)); + searchFeatureDao + .getFeatureSamplesForPeriods(detector, sampleRanges, AnalysisType.AD, getSamplesRangesListener(sampleRanges, listener)); } /** @@ -677,11 +732,7 @@ public SinglePointFeatures getShingledFeatureForHistoricalAnalysis( return getProcessedFeatures(shingle, detector, endTime); } - private SinglePointFeatures getProcessedFeatures( - Deque>> shingle, - AnomalyDetector detector, - long endTime - ) { + private SinglePointFeatures getProcessedFeatures(Deque>> shingle, Config detector, long endTime) { int shingleSize = detector.getShingleSize(); Optional currentPoint = shingle.peekLast().getValue(); return new SinglePointFeatures( @@ -694,4 +745,168 @@ private SinglePointFeatures getProcessedFeatures( ); } + /** + * Extract sample array from samples and 
currentUnprocessed. Impute if necessary. + * Whether to use the provided samples is subject to the time ordering of lastProcessed + * and the provided samples. We throw away unProcessedSamples or currentUnprocessed if + * they are older than lastProcessed. + * + * @param config analysis config. + * @param unProcessedSamples unprocessed samples stored in memory + * @param lastProcessed last processed sample. + * @param currentUnprocessed current unprocessed sample. + * @return continuous samples with possible imputations. + */ + public Pair getContinuousSamples( + Config config, + Deque unProcessedSamples, + Sample lastProcessed, + Sample currentUnprocessed + ) { + Deque samples = new ArrayDeque<>(); + if (lastProcessed != null) { + if (unProcessedSamples != null && !unProcessedSamples.isEmpty()) { + Sample lastElement = unProcessedSamples.getLast(); + if (lastElement != null && lastElement.getDataEndTime().compareTo(lastProcessed.getDataEndTime()) > 0) { + samples.add(lastProcessed); + samples.addAll(unProcessedSamples); + } + } else { + samples.add(lastProcessed); + } + } + + if (currentUnprocessed != null) { + samples.add(currentUnprocessed); + } + + if (samples.isEmpty()) { + return Pair.of(new double[0][0], lastProcessed); + } else { + return removeLastSeenSample(getContinuousSamples(config, samples), lastProcessed); + } + } + + /** + * Remove the first sample because it has been used before and is included only for interpolation. + * @param res input data and sample pair + * @param previousLastSeenSample last seen sample + * @return input without the first sample we have seen + */ + private Pair removeLastSeenSample(Pair res, Sample previousLastSeenSample) { + double[][] values = res.getKey(); + + if (previousLastSeenSample != null && values.length > 1) { + return Pair.of(Arrays.copyOfRange(values, 1, values.length), res.getValue()); + } else if (values.length > 0) { + return Pair.of(values, res.getValue()); + } + + return Pair.of(new double[0][0], res.getValue()); + }
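A usage sketch may help here. The Sample constructor signature is taken from its uses later in this diff; featureManager, config, and lastProcessed are placeholders.
// Stitch a stored sample plus the current one into a continuous training block.
Deque<Sample> unprocessed = new ArrayDeque<>();
unprocessed.add(new Sample(new double[] { 1.0 }, Instant.ofEpochMilli(60_000L), Instant.ofEpochMilli(120_000L)));
Sample current = new Sample(new double[] { 2.0 }, Instant.ofEpochMilli(120_000L), Instant.ofEpochMilli(180_000L));
Pair<double[][], Sample> continuous = featureManager.getContinuousSamples(config, unprocessed, lastProcessed, current);
double[][] points = continuous.getKey(); // equally spaced; gaps imputed, leading NaN rows trimmed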
/** + * Extract samples from the input queue. Impute if necessary. + * + * @param config analysis config. + * @param samples samples accumulated from previous job runs. + * @return continuous samples with possible imputation and the last seen sample. When samples is empty, return an + * empty double array and an empty last seen sample. + */ + public Pair getContinuousSamples(Config config, Queue samples) { + // To allow for small time variations/delays in running the config. + long maxMillisecondsDifference = config.getIntervalInMilliseconds() / 2; + + TreeMap search = new TreeMap<>(); + // Iterate over the sample queue in FIFO order. + Iterator iterator = samples.iterator(); + long startTimeMillis = 0; + Sample lastElement = null; + while (iterator.hasNext()) { + lastElement = iterator.next(); + long dataEndTimeMillis = lastElement.getDataEndTime().toEpochMilli(); + if (startTimeMillis == 0) { + startTimeMillis = dataEndTimeMillis; + } + double[] valueList = lastElement.getValueList(); + search.put(dataEndTimeMillis, valueList); + } + + if (startTimeMillis == 0 || lastElement == null) { + return Pair.of(new double[0][0], new Sample()); + } + + long endTimeMillis = lastElement.getDataEndTime().toEpochMilli(); + + // There can be small time variations/delays in running the analysis. + // Training data end times are adjusted using the end time and the interval so that + // neighboring samples are equally spaced. This helps find missing + // data ranges and apply interpolation. + // adjustedTrainingData maps the computed millisecond timestamp associated with an + // interval in the training data to an entry that contains the actual timestamp of + // the data point and an optional data point value. For example, with a 60 second + // interval, a sample ending at 10:00:58 is matched to the 10:01:00 slot if the + // ceiling entry is at least as close as the floor entry. + List adjustedDataEndTime = getFullTrainingDataEndTimes(endTimeMillis, config.getIntervalInMilliseconds(), startTimeMillis); + Map> adjustedTrainingData = adjustedDataEndTime.stream().map(t -> { + Optional> after = Optional.ofNullable(search.ceilingEntry(t)); + Optional> before = Optional.ofNullable(search.floorEntry(t)); + return after + .filter(a -> Math.abs(t - a.getKey()) <= before.map(b -> Math.abs(t - b.getKey())).orElse(Long.MAX_VALUE)) + .map(Optional::of) + .orElse(before) + // training data not within the max difference range is filtered out; the corresponding t becomes Optional.empty and + // is filtered out later as well + .filter(e -> Math.abs(t - e.getKey()) < maxMillisecondsDifference) + .map(e -> new SimpleImmutableEntry<>(t, e)); + }) + .filter(Optional::isPresent) + .map(Optional::get) + .collect( + Collectors + .toMap( + Entry::getKey, // Key mapper + Entry::getValue, // Value mapper + (v1, v2) -> v1, // Merge function + TreeMap::new + ) // Map implementation to order the entries by key value + ); + + // convert from long to int as we don't expect a huge number of samples + int totalNumSamples = adjustedDataEndTime.size(); + int numEnabledFeatures = config.getEnabledFeatureIds().size(); + double[][] trainingData = new double[totalNumSamples][numEnabledFeatures]; + + Iterator adjustedEndTimeIterator = adjustedDataEndTime.iterator(); + for (int index = 0; adjustedEndTimeIterator.hasNext(); index++) { + long time = adjustedEndTimeIterator.next(); + Entry entry = adjustedTrainingData.get(time); + if (entry != null) { + // feature values whose end time is closest to this adjusted timestamp + trainingData[index] = entry.getValue(); + } else { + // create an array of Double.NaN + trainingData[index] = DoubleStream.generate(() -> Double.NaN).limit(numEnabledFeatures).toArray(); + } + } + + Imputer imputer = config.getImputer(); + return Pair.of(DataUtil.ltrim(imputer.impute(trainingData, totalNumSamples)), lastElement); + } + + /** + * + * @param endTime end time of the stream + * @param intervalMilli interval between returned timestamps + * @param startTime start time of the stream + * @return a list of epoch timestamps starting at startTime and spaced intervalMilli apart, stopping once a timestamp would pass endTime. + */ + private List getFullTrainingDataEndTimes(long endTime, long intervalMilli, long startTime) { + return LongStream + .iterate(startTime, i -> i + intervalMilli) + .takeWhile(i -> i <= endTime) + .boxed() // convert LongStream to Stream + .collect(Collectors.toList()); // collect to List + } } diff --git a/src/main/java/org/opensearch/ad/feature/Features.java b/src/main/java/org/opensearch/timeseries/feature/Features.java similarity index 98% rename from src/main/java/org/opensearch/ad/feature/Features.java rename to src/main/java/org/opensearch/timeseries/feature/Features.java index de347b78f..13cefc1d8 100644 --- a/src/main/java/org/opensearch/ad/feature/Features.java +++ b/src/main/java/org/opensearch/timeseries/feature/Features.java @@ -9,7 +9,7 @@ * GitHub history for details.
*/ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Arrays; import java.util.List; diff --git a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java similarity index 95% rename from src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java rename to src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java index 557e98fd7..c277d1cfb 100644 --- a/src/main/java/org/opensearch/ad/feature/SearchFeatureDao.java +++ b/src/main/java/org/opensearch/timeseries/feature/SearchFeatureDao.java @@ -9,11 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import static org.apache.commons.math3.linear.MatrixUtils.createRealMatrix; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.PREVIEW_TIMEOUT_IN_MILLIS; import static org.opensearch.timeseries.util.ParseUtils.batchFeatureQuery; @@ -40,7 +40,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -65,12 +64,15 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortOrder; +import org.opensearch.timeseries.AnalysisType; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; /** * DAO for features from search. 
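One hunk below generalizes getMinDataTime from detectors to any Config; the query it builds is essentially a min aggregation over the configured time field. Roughly, mirroring the code in that hunk (names illustrative):
// Rough shape of the earliest-data-time query (illustrative, not this PR's exact code).
SearchSourceBuilder source = new SearchSourceBuilder()
    .query(QueryBuilders.boolQuery()) // plus per-entity term filters when an entity is present
    .aggregation(AggregationBuilders.min("min_timefield").field(config.getTimeField()))
    .trackTotalHits(false)
    .size(0); // aggregation only; no hits needed
SearchRequest request = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(source);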
@@ -117,7 +119,7 @@ public SearchFeatureDao( if (clusterService != null) { clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ENTITIES_FOR_PREVIEW, it -> this.maxEntitiesForPreview = it); - clusterService.getClusterSettings().addSettingsUpdateConsumer(PAGE_SIZE, it -> this.pageSize = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_PAGE_SIZE, it -> this.pageSize = it); } this.minimumDocCountForPreview = minimumDocCount; this.previewTimeoutInMilliseconds = previewTimeoutInMilliseconds; @@ -155,7 +157,7 @@ public SearchFeatureDao( minimumDocCount, Clock.systemUTC(), MAX_ENTITIES_FOR_PREVIEW.get(settings), - PAGE_SIZE.get(settings), + AD_PAGE_SIZE.get(settings), PREVIEW_TIMEOUT_IN_MILLIS ); } @@ -181,6 +183,7 @@ public void getLatestDataTime(AnomalyDetector detector, ActionListener> listener) { + public void getMinDataTime(Config config, Optional entity, AnalysisType context, ActionListener> listener) { BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery(); - for (TermQueryBuilder term : entity.getTermQueryBuilders()) { - internalFilterQuery.filter(term); + if (entity.isPresent()) { + for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) { + internalFilterQuery.filter(term); + } } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .query(internalFilterQuery) - .aggregation(AggregationBuilders.min(AGG_NAME_MIN).field(detector.getTimeField())) + .aggregation(AggregationBuilders.min(AGG_NAME_MIN).field(config.getTimeField())) .trackTotalHits(false) .size(0); - SearchRequest searchRequest = new SearchRequest().indices(detector.getIndices().toArray(new String[0])).source(searchSourceBuilder); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); final ActionListener searchResponseListener = ActionListener .wrap(response -> { listener.onResponse(parseMinDataTime(response)); }, listener::onFailure); // inject user role while searching. @@ -494,8 +501,9 @@ public void getEntityMinDataTime(AnomalyDetector detector, Entity entity, Action .asyncRequestWithInjectedSecurity( searchRequest, client::search, - detector.getId(), + config.getId(), client, + context, searchResponseListener ); } @@ -529,6 +537,7 @@ public void getFeaturesForPeriod(AnomalyDetector detector, long startTime, long client::search, detector.getId(), client, + AnalysisType.AD, searchResponseListener ); } @@ -556,6 +565,7 @@ public void getFeaturesForPeriodByBatch( client::search, detector.getId(), client, + AnalysisType.AD, searchResponseListener ); } @@ -583,24 +593,24 @@ public Optional parseResponse(SearchResponse response, List fe * * Sampled features are not true features. They are intended to be approximate results produced at low costs. 
* - * @param detector info about the indices, documents, feature query + * @param config info about the indices, documents, feature query * @param ranges list of time ranges * @param listener handle approximate features for the time ranges * @throws IOException if a user gives wrong query input when defining a detector */ public void getFeatureSamplesForPeriods( - AnomalyDetector detector, + Config config, List> ranges, + AnalysisType context, ActionListener>> listener ) throws IOException { - SearchRequest request = createPreviewSearchRequest(detector, ranges); + SearchRequest request = createPreviewSearchRequest(config, ranges); final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { listener.onResponse(Collections.emptyList()); return; } - listener .onResponse( aggs @@ -608,7 +618,7 @@ public void getFeatureSamplesForPeriods( .stream() .filter(InternalDateRange.class::isInstance) .flatMap(agg -> ((InternalDateRange) agg).getBuckets().stream()) - .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) + .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds())) .collect(Collectors.toList()) ); }, listener::onFailure); @@ -617,8 +627,9 @@ public void getFeatureSamplesForPeriods( .asyncRequestWithInjectedSecurity( request, client::search, - detector.getId(), + config.getId(), client, + context, searchResponseListener ); } @@ -842,24 +853,25 @@ private SearchRequest createFeatureSearchRequest(AnomalyDetector detector, long } } - private SearchRequest createPreviewSearchRequest(AnomalyDetector detector, List> ranges) throws IOException { + private SearchRequest createPreviewSearchRequest(Config config, List> ranges) throws IOException { try { - SearchSourceBuilder searchSourceBuilder = ParseUtils.generatePreviewQuery(detector, ranges, xContent); - return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); + SearchSourceBuilder searchSourceBuilder = ParseUtils.generatePreviewQuery(config, ranges, xContent); + return new SearchRequest(config.getIndices().toArray(new String[0]), searchSourceBuilder); } catch (IOException e) { - logger.warn("Failed to create feature search request for " + detector.getId() + " for preview", e); + logger.warn("Failed to create feature search request for " + config.getId() + " for preview", e); throw e; } } public void getColdStartSamplesForPeriods( - AnomalyDetector detector, + Config config, List> ranges, - Entity entity, + Optional entity, boolean includesEmptyBucket, + AnalysisType context, ActionListener>> listener ) { - SearchRequest request = createColdStartFeatureSearchRequest(detector, ranges, entity); + SearchRequest request = createColdStartFeatureSearchRequest(config, ranges, entity); final ActionListener searchResponseListener = ActionListener.wrap(response -> { Aggregations aggs = response.getAggregations(); if (aggs == null) { @@ -889,7 +901,7 @@ public void getColdStartSamplesForPeriods( .filter(bucket -> bucket.getFrom() != null && bucket.getFrom() instanceof ZonedDateTime) .filter(bucket -> bucket.getDocCount() > docCountThreshold) .sorted(Comparator.comparing((Bucket bucket) -> (ZonedDateTime) bucket.getFrom())) - .map(bucket -> parseBucket(bucket, detector.getEnabledFeatureIds())) + .map(bucket -> parseBucket(bucket, config.getEnabledFeatureIds())) .collect(Collectors.toList()) ); }, listener::onFailure); @@ -899,15 +911,16 @@ public void getColdStartSamplesForPeriods( 
.asyncRequestWithInjectedSecurity( request, client::search, - detector.getId(), + config.getId(), client, + context, searchResponseListener ); } - private SearchRequest createColdStartFeatureSearchRequest(AnomalyDetector detector, List> ranges, Entity entity) { + private SearchRequest createColdStartFeatureSearchRequest(Config detector, List> ranges, Optional entity) { try { - SearchSourceBuilder searchSourceBuilder = ParseUtils.generateEntityColdStartQuery(detector, ranges, entity, xContent); + SearchSourceBuilder searchSourceBuilder = ParseUtils.generateColdStartQuery(detector, ranges, entity, xContent); return new SearchRequest(detector.getIndices().toArray(new String[0]), searchSourceBuilder); } catch (IOException e) { logger diff --git a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java similarity index 97% rename from src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java rename to src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java index cbd7ef78b..9849a67f8 100644 --- a/src/main/java/org/opensearch/ad/feature/SinglePointFeatures.java +++ b/src/main/java/org/opensearch/timeseries/feature/SinglePointFeatures.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.feature; +package org.opensearch.timeseries.feature; import java.util.Optional; diff --git a/src/main/java/org/opensearch/timeseries/function/BiCheckedFunction.java b/src/main/java/org/opensearch/timeseries/function/BiCheckedFunction.java new file mode 100644 index 000000000..d96b14adf --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/function/BiCheckedFunction.java @@ -0,0 +1,11 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.function; + +@FunctionalInterface +public interface BiCheckedFunction { + R apply(T t, F f) throws E; +} diff --git a/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java new file mode 100644 index 000000000..c6a223bdd --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/CheckpointDao.java @@ -0,0 +1,337 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ml; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; + +import org.apache.commons.pool2.impl.GenericObjectPool; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetAction; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.get.MultiGetAction; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.reindex.BulkByScrollResponse; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.ScrollableHitSource; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.util.ClientUtil; + +import com.google.gson.Gson; +import io.protostuff.LinkedBuffer; + +public abstract class CheckpointDao & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + private static final Logger logger = LogManager.getLogger(CheckpointDao.class); + public static final String TIMEOUT_LOG_MSG = "Timeout while deleting checkpoints of"; + public static final String BULK_FAILURE_LOG_MSG = "Bulk failure while deleting checkpoints of"; + public static final String SEARCH_FAILURE_LOG_MSG = "Search failure while deleting checkpoints of"; + public static final String DOC_GOT_DELETED_LOG_MSG = "checkpoints docs get deleted"; + public static final String INDEX_DELETED_LOG_MSG = "Checkpoint index has been deleted. 
Has nothing to do:"; + + // dependencies + protected final Client client; + protected final ClientUtil clientUtil; + + // configuration + protected final String indexName; + + protected Gson gson; + + // we won't read/write a checkpoint larger than a threshold + protected final int maxCheckpointBytes; + + protected final GenericObjectPool serializeRCFBufferPool; + protected final int serializeRCFBufferSize; + + protected final IndexManagement indexUtil; + protected final Clock clock; + public static final String NOT_ABLE_TO_DELETE_CHECKPOINT_MSG = "Cannot delete all checkpoints of detector"; + + public CheckpointDao( + Client client, + ClientUtil clientUtil, + String indexName, + Gson gson, + int maxCheckpointBytes, + GenericObjectPool serializeRCFBufferPool, + int serializeRCFBufferSize, + IndexManagementType indexUtil, + Clock clock + ) { + this.client = client; + this.clientUtil = clientUtil; + this.indexName = indexName; + this.gson = gson; + this.maxCheckpointBytes = maxCheckpointBytes; + this.serializeRCFBufferPool = serializeRCFBufferPool; + this.serializeRCFBufferSize = serializeRCFBufferSize; + this.indexUtil = indexUtil; + this.clock = clock; + } + + protected void putModelCheckpoint(String modelId, Map source, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + saveModelCheckpointAsync(source, modelId, listener); + } else { + onCheckpointNotExist(source, modelId, listener); + } + } + + /** + * Update the model doc using fields in source. This ensures we won't touch + * the old checkpoint and nodes with old/new logic can coexist in a cluster. + * This is useful for introducing compact rcf new model format. + * + * @param source fields to update + * @param modelId model Id, used as doc id in the checkpoint index + * @param listener Listener to return response + */ + protected void saveModelCheckpointAsync(Map source, String modelId, ActionListener listener) { + + UpdateRequest updateRequest = new UpdateRequest(indexName, modelId); + updateRequest.doc(source); + // If the document does not already exist, the contents of the upsert element are inserted as a new document. 
+ // If the document exists, update fields in the map + updateRequest.docAsUpsert(true); + clientUtil + .asyncRequest( + updateRequest, + client::update, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + protected void onCheckpointNotExist(Map source, String modelId, ActionListener listener) { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + saveModelCheckpointAsync(source, modelId, listener); + + } else { + throw new RuntimeException("Creating checkpoint with mappings call not acknowledged."); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we were sending the create request + saveModelCheckpointAsync(source, modelId, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating index %s", indexName), exception); + } + })); + } + + protected Map.Entry checkoutOrNewBuffer() { + LinkedBuffer buffer = null; + boolean isCheckout = true; + try { + buffer = serializeRCFBufferPool.borrowObject(); + } catch (Exception e) { + logger.warn("Failed to borrow a buffer from pool", e); + } + if (buffer == null) { + buffer = LinkedBuffer.allocate(serializeRCFBufferSize); + isCheckout = false; + } + return new SimpleImmutableEntry(buffer, isCheckout); + } + + /** + * Deletes the model checkpoint for the model. + * + * @param modelId id of the model + * @param listener onResponse is called with null when the operation is completed + */ + public void deleteModelCheckpoint(String modelId, ActionListener listener) { + clientUtil + .asyncRequest( + new DeleteRequest(indexName, modelId), + client::delete, + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ); + } + + protected void logFailure(BulkByScrollResponse response, String id) { + if (response.isTimedOut()) { + logger.warn(CheckpointDao.TIMEOUT_LOG_MSG + " {}", id); + } else if (!response.getBulkFailures().isEmpty()) { + logger.warn(CheckpointDao.BULK_FAILURE_LOG_MSG + " {}", id); + for (BulkItemResponse.Failure bulkFailure : response.getBulkFailures()) { + logger.warn(bulkFailure); + } + } else { + logger.warn(CheckpointDao.SEARCH_FAILURE_LOG_MSG + " {}", id); + for (ScrollableHitSource.SearchFailure searchFailure : response.getSearchFailures()) { + logger.warn(searchFailure); + } + } + } + + /** + * Should we save the checkpoint or not + * @param lastCheckpointTime last checkpoint time + * @param forceWrite save no matter what + * @param checkpointInterval checkpoint interval + * @param clock UTC clock + * + * @return true when forceWrite is true or we haven't saved a checkpoint in the + * last checkpoint interval; false otherwise + */ + public boolean shouldSave(Instant lastCheckpointTime, boolean forceWrite, Duration checkpointInterval, Clock clock) { + return (lastCheckpointTime != Instant.MIN && lastCheckpointTime.plus(checkpointInterval).isBefore(clock.instant())) || forceWrite; + }
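A small worked example of the gating above (values illustrative):
// With a 12-hour interval: a checkpoint saved at 00:00 is saved again at 13:00
// (interval elapsed) but skipped at 05:00 unless forceWrite is true. A model that
// has never been saved (Instant.MIN) is only written when forceWrite is true.
Duration checkpointInterval = Duration.ofHours(12);
Instant lastSave = Instant.parse("2023-01-01T00:00:00Z");
boolean save = shouldSave(lastSave, false, checkpointInterval, clock);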
+ public void batchWrite(BulkRequest request, ActionListener listener) { + if (indexUtil.doesCheckpointIndexExist()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + indexUtil.initCheckpointIndex(ActionListener.wrap(initResponse -> { + if (initResponse.isAcknowledged()) { + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + // create index failure. Notify callers using listener. + listener.onFailure(new TimeSeriesException("Creating checkpoint with mappings call not acknowledged.")); + } + }, exception -> { + if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) { + // It is possible the index has been created while we were sending the create request + clientUtil.execute(BulkAction.INSTANCE, request, listener); + } else { + logger.error(String.format(Locale.ROOT, "Unexpected error creating checkpoint index"), exception); + listener.onFailure(exception); + } + })); + } + } + + /** + * Serialize a queue of samples for storage in a checkpoint document. + * @param samples accumulated samples + * @return the serialized samples as an object array, or empty if samples is null + */ + protected Optional toCheckpoint(Queue samples) { + if (samples == null) { + return Optional.empty(); + } + return Optional.of(samples.toArray()); + } + + public void batchRead(MultiGetRequest request, ActionListener listener) { + clientUtil.execute(MultiGetAction.INSTANCE, request, listener); + } + + public void read(GetRequest request, ActionListener listener) { + clientUtil.execute(GetAction.INSTANCE, request, listener); + } + + /** + * Delete checkpoints associated with a config. Used in multi-entity detectors. + * @param configId Config Id + */ + public void deleteModelCheckpointByConfigId(String configId) { + // A bulk delete request is performed for each batch of matching documents. If a + // search or bulk request is rejected, the requests are retried up to 10 times, + // with exponential back off. If the maximum retry limit is reached, processing + // halts and all failed requests are returned in the response. Any delete + // requests that completed successfully still stick; they are not rolled back. + DeleteByQueryRequest deleteRequest = createDeleteCheckpointRequest(configId); + logger.info("Delete checkpoints of config {}", configId); + client.execute(DeleteByQueryAction.INSTANCE, deleteRequest, ActionListener.wrap(response -> { + if (response.isTimedOut() || !response.getBulkFailures().isEmpty() || !response.getSearchFailures().isEmpty()) { + logFailure(response, configId); + } + // The response can report 0 deleted docs because: + // 1) we cannot find matching docs, or + // 2) bad stats from OpenSearch. In this case, docs are deleted, but + // OpenSearch says deleted is 0. + logger.info("{} " + CheckpointDao.DOC_GOT_DELETED_LOG_MSG, response.getDeleted()); + }, exception -> { + if (exception instanceof IndexNotFoundException) { + logger.info(CheckpointDao.INDEX_DELETED_LOG_MSG + " {}", configId); + } else { + // The daily cron will eventually delete these checkpoints. + logger.error(NOT_ABLE_TO_DELETE_CHECKPOINT_MSG, exception); + } + })); + }
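createDeleteCheckpointRequest is abstract; a concrete AD-side implementation would presumably look like the sketch below. The "detector_id" field name is an assumption for illustration, not taken from this PR.
// Hypothetical subclass implementation; the "detector_id" field name is assumed.
@Override
protected DeleteByQueryRequest createDeleteCheckpointRequest(String configId) {
    return new DeleteByQueryRequest(indexName)
        .setQuery(new MatchQueryBuilder("detector_id", configId))
        .setRefresh(true)
        .setAbortOnVersionConflict(false); // retry on conflict rather than abort
}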
+ protected Optional> processRawCheckpoint(GetResponse response) { + return Optional.ofNullable(response).filter(GetResponse::isExists).map(GetResponse::getSource); + } + + /** + * Process a checkpoint GetResponse and return the model state of an HC entity model + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @param configId Config Id + * @return the model state restored from the checkpoint, or null if the checkpoint does not exist + */ + public ModelState processHCGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromEntityModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + /** + * Process a checkpoint GetResponse and return the model state of a single-stream model + * @param response Checkpoint Index GetResponse + * @param modelId Model Id + * @param configId Config Id + * @return the model state restored from the checkpoint, or null if the checkpoint does not exist + */ + public ModelState processSingleStreamGetResponse(GetResponse response, String modelId, String configId) { + Optional> checkpointString = processRawCheckpoint(response); + if (checkpointString.isPresent()) { + return fromSingleStreamModelCheckpoint(checkpointString.get(), modelId, configId); + } else { + return null; + } + } + + protected abstract ModelState fromEntityModelCheckpoint(Map checkpoint, String modelId, String configId); + + protected abstract ModelState fromSingleStreamModelCheckpoint( + Map checkpoint, + String modelId, + String configId + ); + + public abstract Map toIndexSource(ModelState modelState) throws IOException; + + protected abstract DeleteByQueryRequest createDeleteCheckpointRequest(String configId); +}
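For context, batchRead above is typically fed a MultiGetRequest keyed by model ids. A hedged usage sketch (index name and ids illustrative):
// Fetch several checkpoints in one round trip.
MultiGetRequest request = new MultiGetRequest();
for (String modelId : List.of("model-1", "model-2")) {
    request.add(new MultiGetRequest.Item("checkpoint-index", modelId));
}
checkpointDao.batchRead(request, ActionListener.wrap(response -> {
    // response.getResponses() preserves request order; each item may be missing or failed.
}, e -> logger.error("batch read failed", e)));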
diff --git a/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java new file mode 100644 index 000000000..8e5b7bcf3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/MemoryAwareConcurrentHashmap.java @@ -0,0 +1,102 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.MemoryTracker.Origin; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * A customized ConcurrentHashMap that can automatically consume and release memory. + * This requires minimal change to our single-stream code, as we only have to replace + * the map implementation. + * + * Note: this is mainly used for single-stream configs. The key is model id. + */ +public class MemoryAwareConcurrentHashmap extends + ConcurrentHashMap> { + protected final MemoryTracker memoryTracker; + + public MemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) { + this.memoryTracker = memoryTracker; + } + + @Override + public ModelState remove(Object key) { + ModelState deletedModelState = super.remove(key); + if (deletedModelState != null && deletedModelState.getModel().isPresent()) { + long memoryToRelease = memoryTracker.estimateTRCFModelSize(deletedModelState.getModel().get()); + memoryTracker.releaseMemory(memoryToRelease, true, Origin.REAL_TIME_DETECTOR); + } + return deletedModelState; + } + + @Override + public ModelState put(String key, ModelState value) { + ModelState previousAssociatedState = super.put(key, value); + if (value != null && value.getModel().isPresent()) { + long memoryToConsume = memoryTracker.estimateTRCFModelSize(value.getModel().get()); + memoryTracker.consumeMemory(memoryToConsume, true, Origin.REAL_TIME_DETECTOR); + } + return previousAssociatedState; + } + + /** + * Gets all of a config's model sizes hosted on a node + * + * @param configId Config Id + * @return a map of model id to its memory size + */ + public Map getModelSize(String configId) { + Map res = new HashMap<>(); + super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .forEach(entry -> { + Optional modelOptional = entry.getValue().getModel(); + if (modelOptional.isPresent()) { + res.put(entry.getKey(), memoryTracker.estimateTRCFModelSize(modelOptional.get())); + } + }); + return res; + } + + /** + * Checks if a model exists for the given config. + * @param configId Config Id + * @return true if a model exists, false otherwise. + */ + public boolean doesModelExist(String configId) { + return super.entrySet() + .stream() + .filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId(entry.getKey()).equals(configId)) + .anyMatch(n -> true); + } + + /** + * Hosts the model if the memory tracker allows it. + * @param modelId Model Id + * @param toUpdate model state to host + * @return true if the model is hosted, false otherwise + */ + public boolean hostIfPossible(String modelId, ModelState toUpdate) { + return Optional + .ofNullable(toUpdate) + .filter(state -> state.getModel().isPresent()) + .filter(state -> memoryTracker.isHostingAllowed(modelId, state.getModel().get())) + .map(state -> { + super.put(modelId, toUpdate); + return true; + }) + .orElse(false); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java new file mode 100644 index 000000000..c1cf8c698 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelColdStart.java @@ -0,0 +1,586 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import java.util.stream.DoubleStream; +import java.util.stream.Stream; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.core.util.Throwables; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ThreadedActionListener; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.CleanState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.DoorKeeper; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.dataprocessor.Imputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * The class bootstraps a model by performing a cold start + * + * @param Node state type + * @param RCF model type + * @param CheckpointDao type + * @param CheckpointWriteWorkerType + */ +public abstract class ModelColdStart & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker> + implements + MaintenanceState, + CleanState { + private static final Logger logger = LogManager.getLogger(ModelColdStart.class); + + private final Duration modelTtl; + + // A bloom filter checked before cold start to ensure we don't repeatedly + // retry cold start of the same model. + // keys are detector ids. + protected Map doorKeepers; + protected Instant lastThrottledColdStartTime; + protected int coolDownMinutes; + protected final Clock clock; + protected final ThreadPool threadPool; + protected final int numMinSamples; + protected CheckpointWriteWorkerType checkpointWriteWorker; + // make sure rcf use a specific random seed. Otherwise, we will use a random random (not a typo) seed. 
+ // this is mainly used for testing to make sure the model we trained and the reference rcf produce + // the same results + protected final long rcfSeed; + protected final int numberOfTrees; + protected final int rcfSampleSize; + protected final double thresholdMinPvalue; + protected final double rcfTimeDecay; + protected final double initialAcceptFraction; + protected final NodeStateManager nodeStateManager; + protected final int defaultStrideLength; + protected final int defaultNumberOfSamples; + protected final SearchFeatureDao searchFeatureDao; + protected final FeatureManager featureManager; + protected final int maxRoundofColdStart; + protected final String threadPoolName; + protected final AnalysisType context; + + public ModelColdStart( + Duration modelTtl, + int coolDownMinutes, + Clock clock, + ThreadPool threadPool, + int numMinSamples, + CheckpointWriteWorkerType checkpointWriteWorker, + long rcfSeed, + int numberOfTrees, + int rcfSampleSize, + double thresholdMinPvalue, + double rcfTimeDecay, + NodeStateManager nodeStateManager, + int defaultSampleStride, + int defaultTrainSamples, + SearchFeatureDao searchFeatureDao, + FeatureManager featureManager, + int maxRoundofColdStart, + String threadPoolName, + AnalysisType context + ) { + this.modelTtl = modelTtl; + this.coolDownMinutes = coolDownMinutes; + this.clock = clock; + this.threadPool = threadPool; + this.numMinSamples = numMinSamples; + this.checkpointWriteWorker = checkpointWriteWorker; + this.rcfSeed = rcfSeed; + this.numberOfTrees = numberOfTrees; + this.rcfSampleSize = rcfSampleSize; + this.thresholdMinPvalue = thresholdMinPvalue; + this.rcfTimeDecay = rcfTimeDecay; + + this.doorKeepers = new ConcurrentHashMap<>(); + this.lastThrottledColdStartTime = Instant.MIN; + this.initialAcceptFraction = numMinSamples * 1.0d / rcfSampleSize; + + this.nodeStateManager = nodeStateManager; + this.defaultStrideLength = defaultSampleStride; + this.defaultNumberOfSamples = defaultTrainSamples; + this.searchFeatureDao = searchFeatureDao; + this.featureManager = featureManager; + this.maxRoundofColdStart = maxRoundofColdStart; + this.threadPoolName = threadPoolName; + this.context = context; + } + + @Override + public void maintenance() { + doorKeepers.entrySet().stream().forEach(doorKeeperEntry -> { + String id = doorKeeperEntry.getKey(); + DoorKeeper doorKeeper = doorKeeperEntry.getValue(); + if (doorKeeper.expired(modelTtl)) { + doorKeepers.remove(id); + } else { + doorKeeper.maintenance(); + } + }); + } + + @Override + public void clear(String id) { + doorKeepers.remove(id); + }
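The rcfSeed plumbing above exists so tests can pin the forest's randomness. A minimal sketch of what a seeded model construction might look like; parameter values are illustrative, not the plugin's defaults.
// Seeded ThresholdedRandomCutForest so tests are reproducible (illustrative values).
ThresholdedRandomCutForest forest = ThresholdedRandomCutForest
    .builder()
    .dimensions(shingleSize * numberOfFeatures)
    .shingleSize(shingleSize)
    .internalShinglingEnabled(true)
    .numberOfTrees(30)
    .sampleSize(256)
    .timeDecay(0.0001)
    .randomSeed(42L) // fixed seed; production code passes rcfSeed through instead
    .build();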
+ /** + * Train models + * @param entity the entity info if we are training for an HC entity + * @param configId Config Id + * @param modelState model state + * @param listener callback invoked whenever ColdStarter + * finishes training or encounters an exception. The listener helps notify the + * cold start queue to pull another request (if any) to execute. + */ + public void trainModel(Optional entity, String configId, ModelState modelState, ActionListener listener) { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(detectorOptional -> { + if (false == detectorOptional.isPresent()) { + logger.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); + listener.onFailure(new TimeSeriesException(configId, "fail to find config")); + return; + } + + Config config = detectorOptional.get(); + + String modelId = modelState.getModelId(); + + if (modelState.getSamples().size() < this.numMinSamples) { + // we cannot get the last RCF score since cold start happens asynchronously + coldStart(modelId, entity, modelState, config, listener); + } else { + try { + trainModelFromExistingSamples(modelState, entity, config); + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + } + }, listener::onFailure)); + } + + public void trainModelFromExistingSamples(ModelState modelState, Optional entity, Config config) { + Pair continuousSamples = featureManager.getContinuousSamples(config, modelState.getSamples()); + trainModelFromDataSegments(continuousSamples, entity, modelState, config); + } + + /** + * Train a model via cold start + * @param modelId model Id corresponding to the entity + * @param entity the entity's information if we are training for an HC entity + * @param modelState model state + * @param config config accessor + * @param listener callback to invoke after cold start + */ + private void coldStart( + String modelId, + Optional entity, + ModelState modelState, + Config config, + ActionListener listener + ) { + logger.debug("Trigger cold start for {}", modelId); + + if (modelState == null) { + listener.onFailure(new IllegalArgumentException(String.format(Locale.ROOT, "Cannot have empty model state"))); + return; + } + + if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) { + listener.onResponse(null); + return; + } + + String configId = config.getId(); + boolean earlyExit = true; + try { + DoorKeeper doorKeeper = doorKeepers + .computeIfAbsent( + configId, + id -> { + // reset every 60 intervals + return new DoorKeeper( + TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION, + TimeSeriesSettings.DOOR_KEEPER_FALSE_POSITIVE_RATE, + config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ), + clock + ); + } + ); + + // Won't retry cold start within 60 intervals for an entity + if (doorKeeper.mightContain(modelId)) { + return; + } + + doorKeeper.put(modelId); + + ActionListener> coldStartCallBack = ActionListener.wrap(trainingData -> { + try { + if (trainingData != null && trainingData.getKey() != null) { + double[][] dataPoints = trainingData.getKey(); + // only train models if we have enough samples + if (dataPoints.length >= numMinSamples) { + // The function trainModelFromDataSegments will save a trained model. trainModelFromDataSegments is called from + // multiple places, so the saving is made implicit to avoid forgetting it anywhere.
+ trainModelFromDataSegments(trainingData, entity, modelState, config); + logger.info("Succeeded in training entity: {}", modelId); + } else { + // save to checkpoint + checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM); + logger.info("Not enough data to train model: {}, currently we have {}", modelId, dataPoints.length); + } + } else { + logger.info("Cannot get training data for {}", modelId); + } + listener.onResponse(null); + } catch (Exception e) { + listener.onFailure(e); + } + }, exception -> { + try { + logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception); + Throwable cause = Throwables.getRootCause(exception); + if (ExceptionUtil.isOverloaded(cause)) { + logger.error("too many requests"); + lastThrottledColdStartTime = Instant.now(); + } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) { + // e.g., cannot find anomaly detector + nodeStateManager.setException(configId, exception); + } else { + nodeStateManager.setException(configId, new TimeSeriesException(configId, cause)); + } + listener.onFailure(exception); + } catch (Exception e) { + listener.onFailure(e); + } + }); + + threadPool + .executor(threadPoolName) + .execute( + () -> getColdStartData( + configId, + entity, + config.getImputer(), + new ThreadedActionListener<>(logger, threadPool, threadPoolName, coldStartCallBack, false) + ) + ); + earlyExit = false; + } finally { + if (earlyExit) { + listener.onResponse(null); + } + } + } + + /** + * Get training data for an entity. + * + * We first note the maximum and minimum timestamp, and sample at most 24 points + * (with 60 points apart between two neighboring samples) between those minimum + * and maximum timestamps. Samples can be missing. We only interpolate points + * between present neighboring samples. We then transform samples and interpolate + * points to shingles. Finally, full shingles will be used for cold start. + * + * @param configId config Id + * @param entity the entity's information + * @param listener listener to return training data + */ + private void getColdStartData( + String configId, + Optional entity, + Imputer imputer, + ActionListener> listener + ) { + ActionListener> getDetectorListener = ActionListener.wrap(configOp -> { + if (!configOp.isPresent()) { + listener.onFailure(new EndRunException(configId, "Config is not available.", false)); + return; + } + Config config = configOp.get(); + + ActionListener> minTimeListener = ActionListener.wrap(earliest -> { + if (earliest.isPresent()) { + long startTimeMs = earliest.get().longValue(); + + // End time uses milliseconds as start time is assumed to be in milliseconds. + // Opensearch uses a set of preconfigured formats to recognize and parse these + // strings into a long value + // representing milliseconds-since-the-epoch in UTC. 
+ // More on https://tinyurl.com/wub4fk92 + + long endTimeMs = clock.millis(); + Pair params = selectRangeParam(config); + int stride = params.getLeft(); + int numberOfSamples = params.getRight(); + + // we start with round 0 + getFeatures( + listener, + 0, + Pair.of(new double[0][0], new Sample()), + config, + entity, + stride, + numberOfSamples, + startTimeMs, + endTimeMs, + imputer + ); + } else { + listener.onResponse(Pair.of(new double[0][0], new Sample())); + } + }, listener::onFailure); + + searchFeatureDao + .getMinDataTime( + config, + entity, + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, minTimeListener, false) + ); + + }, listener::onFailure); + + nodeStateManager + .getConfig(configId, context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, getDetectorListener, false)); + } + + /** + * Select strideLength and numberOfSamples, where strideLength is the number of intervals + * between two samples and numberOfSamples is the number of training samples to fetch. If + * interpolation is disabled, strideLength is 1 and numberOfSamples is shingleSize + numMinSamples. + * + * Algorithm: + * + * delta is the length of the detector interval in minutes. + * + * 1. Suppose delta ≤ 30 and divides 60. Then set numberOfSamples = ceil((shingleSize + 32) / 24) * 24 + * and strideLength = 60/delta. If there is enough data, we may fetch a lot more than shingleSize + 32 + * points, which is only good. This step tries to match the data with an hourly pattern. + * 2. Otherwise, set numberOfSamples = (shingleSize + 32) and strideLength = 1. + * This should be an uncommon case, as we assume most users think in terms of multiples of 5 minutes + * (say 10 or 30 minutes). But if someone wants a 23-minute interval, and the system permits it, + * we give it to them. In this case we effectively disable interpolation, because interpolation is meant + * to follow the hourly pattern. That is why we use 60 as the dividend in case 1; the 23-minute case does + * not fit that pattern. Note the smallest delta that does not divide 60 is 7, which is already a long + * time to wait for a single data point. + * @return the chosen strideLength and numberOfSamples + */ + private Pair selectRangeParam(Config config) { + int shingleSize = config.getShingleSize(); + if (isInterpolationInColdStartEnabled()) { + long delta = config.getIntervalInMinutes(); + + int strideLength = defaulStrideLength; + int numberOfSamples = defaultNumberOfSamples; + if (delta <= 30 && 60 % delta == 0) { + strideLength = (int) (60 / delta); + numberOfSamples = (int) Math.ceil((shingleSize + numMinSamples) / 24.0d) * 24; + } else { + strideLength = 1; + numberOfSamples = shingleSize + numMinSamples; + } + return Pair.of(strideLength, numberOfSamples); + } else { + return Pair.of(1, shingleSize + numMinSamples); + } + + } + + private void getFeatures( + ActionListener> listener, + int round, + Pair lastRounddataSample, + Config config, + Optional entity, + int stride, + int numberOfSamples, + long startTimeMs, + long endTimeMs, + Imputer imputer + ) { + if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < config.getIntervalInMilliseconds()) { + listener.onResponse(lastRounddataSample); + return; + } + + // Create ranges in descending order to make sure the last sample's end time is the given endTimeMs. + // We will reorder the ranges in ascending order in OpenSearch's response.
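+ // Worked example (assumed config: 10-minute interval, shingle size 8): selectRangeParam returns stride 6 and + // numberOfSamples = ceil((8 + 32) / 24.0) * 24 = 48. With endTimeMs = T, the ranges below walk backwards as + // [T-10m, T], [T-70m, T-60m], [T-130m, T-120m], ... so neighboring samples stay one hour apart, newest first.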
+ List> sampleRanges = getTrainSampleRanges(config, startTimeMs, endTimeMs, stride, numberOfSamples); + + if (sampleRanges.isEmpty()) { + listener.onResponse(lastRounddataSample); + return; + } + + ActionListener>> getFeaturelistener = ActionListener.wrap(featureSamples -> { + + if (featureSamples.size() != sampleRanges.size()) { + logger + .error( + "Unexpected mismatch: featureSamples size {} differs from sample ranges size {}.", + featureSamples.size(), + sampleRanges.size() + ); + listener.onResponse(lastRounddataSample); + return; + } + + int totalNumSamples = featureSamples.size(); + int numEnabledFeatures = config.getEnabledFeatureIds().size(); + double[][] trainingData = new double[totalNumSamples][numEnabledFeatures]; + + // featureSamples are in ascending order of time. + for (int index = 0; index < featureSamples.size(); index++) { + Optional featuresOptional = featureSamples.get(index); + if (featuresOptional.isPresent()) { + // featureSamples preserves the order of the requested sample ranges + trainingData[index] = featuresOptional.get(); + } else { + // create an array of Double.NaN + trainingData[index] = DoubleStream.generate(() -> Double.NaN).limit(numEnabledFeatures).toArray(); + } + } + + double[][] currentRoundColdStartData = imputer.impute(trainingData, totalNumSamples); + + Pair concatenatedDataSample = null; + double[][] lastRoundColdStartData = lastRounddataSample.getKey(); + // a non-empty array from the previous round means this is not the first probe round; concatenate before re-imputing + if (lastRoundColdStartData != null && lastRoundColdStartData.length > 0) { + double[][] concatenated = new double[currentRoundColdStartData.length + lastRoundColdStartData.length][numEnabledFeatures]; + System.arraycopy(lastRoundColdStartData, 0, concatenated, 0, lastRoundColdStartData.length); + System + .arraycopy(currentRoundColdStartData, 0, concatenated, lastRoundColdStartData.length, currentRoundColdStartData.length); + trainingData = imputer.impute(concatenated, concatenated.length); + concatenatedDataSample = Pair.of(trainingData, lastRounddataSample.getValue()); + } else { + concatenatedDataSample = Pair + .of( + currentRoundColdStartData, + new Sample( + currentRoundColdStartData[currentRoundColdStartData.length - 1], + Instant.ofEpochMilli(endTimeMs - config.getIntervalInMilliseconds()), + Instant.ofEpochMilli(endTimeMs) + ) + ); + } + + // If the first round of the probe provides (32 + shingleSize) points, we are done (note that if S0 is + // missing, or every Si for some i > N is missing, we would miss a lot of points). + // Otherwise we can issue another round of queries: if there is any sample in the + // second round then we would have 32 + shingleSize points. If there is no sample + // in the second round then we should wait for real data. + if (currentRoundColdStartData.length >= config.getShingleSize() + numMinSamples || round + 1 >= maxRoundofColdStart) { + listener.onResponse(concatenatedDataSample); + } else { + // the earliest sample's start time becomes the endTimeMs of the next round of the probe. + long earliestSampleStartTime = sampleRanges.get(sampleRanges.size() - 1).getKey(); + getFeatures( + listener, + round + 1, + concatenatedDataSample, + config, + entity, + stride, + numberOfSamples, + startTimeMs, + earliestSampleStartTime, + imputer + ); + } + }, listener::onFailure); + + try { + searchFeatureDao + .getColdStartSamplesForPeriods( + config, + sampleRanges, + entity, + // Accept empty bucket.
+ // 0, as returned by the engine, should constitute a valid answer; "null" is a missing answer. 0 may be + // meaningless in some cases but meaningful in others. It may be that the query defining the + // metric is ill-formed, but that cannot be solved by the AD plugin's cold-start strategy; if we attempted to do + // that, we would have issues with legitimate interpretations of 0. + true, + context, + new ThreadedActionListener<>(logger, threadPool, threadPoolName, getFeaturelistener, false) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Get training sample time ranges within a time range. + * + * @param config accessor to config + * @param startMilli range start + * @param endMilli range end + * @param stride the number of intervals between two samples + * @param numberOfSamples maximum training samples to fetch + * @return list of sample time ranges + */ + private List> getTrainSampleRanges(Config config, long startMilli, long endMilli, int stride, int numberOfSamples) { + long bucketSize = ((IntervalTimeConfiguration) config.getInterval()).toDuration().toMillis(); + int numBuckets = (int) Math.floor((endMilli - startMilli) / (double) bucketSize); + // adjust if numStrides is more than the max samples + int numStrides = Math.min((int) Math.floor(numBuckets / (double) stride), numberOfSamples); + List> sampleRanges = Stream + .iterate(endMilli, i -> i - stride * bucketSize) + .limit(numStrides) + .map(time -> new SimpleImmutableEntry<>(time - bucketSize, time)) + .collect(Collectors.toList()); + return sampleRanges; + } + + protected abstract void trainModelFromDataSegments( + Pair dataPoints, + Optional entity, + ModelState state, + Config config + ); + + protected abstract boolean isInterpolationInColdStartEnabled(); +} diff --git a/src/main/java/org/opensearch/timeseries/ml/ModelManager.java b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java new file mode 100644 index 000000000..de2b0f3b1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/ModelManager.java @@ -0,0 +1,192 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.ml; + +import java.time.Clock; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.timeseries.MemoryTracker; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; + +import com.amazon.randomcutforest.RandomCutForest; +import com.amazon.randomcutforest.parkservices.AnomalyDescriptor; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ModelManager & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart> { + + private static final Logger LOG = LogManager.getLogger(ModelManager.class); + + public enum ModelType { + RCF("rcf"), + THRESHOLD("threshold"), + TRCF("trcf"), + RCFCASTER("rcf_caster"); + + private String name; + + ModelType(String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + protected final int rcfNumTrees; + protected final int rcfNumSamplesInTree; + protected final double rcfTimeDecay; + protected final int rcfNumMinSamples; + protected ColdStarterType coldStarter; + protected MemoryTracker memoryTracker; + protected final Clock clock; + protected FeatureManager featureManager; + protected final CheckpointDaoType checkpointDao; + + public ModelManager( + int rcfNumTrees, + int rcfNumSamplesInTree, + double rcfTimeDecay, + int rcfNumMinSamples, + ColdStarterType coldStarter, + MemoryTracker memoryTracker, + Clock clock, + FeatureManager featureManager, + CheckpointDaoType checkpointDao + ) { + this.rcfNumTrees = rcfNumTrees; + this.rcfNumSamplesInTree = rcfNumSamplesInTree; + this.rcfTimeDecay = rcfTimeDecay; + this.rcfNumMinSamples = rcfNumMinSamples; + this.coldStarter = coldStarter; + this.memoryTracker = memoryTracker; + this.clock = clock; + this.featureManager = featureManager; + this.checkpointDao = checkpointDao; + } + + public ResultType getResult( + Sample sample, + ModelState modelState, + String modelId, + Optional entity, + Config config + ) { + ResultType result = createEmptyResult(); + if (modelState != null) { + Optional entityModel = modelState.getModel(); + + if (entityModel.isEmpty()) { + coldStarter.trainModelFromExistingSamples(modelState, entity, config); + } + + if (modelState.getModel().isPresent()) { + result = score(sample, modelId, modelState, config); + } else { + modelState.addSample(sample); + } + } + return result; + } + + public void clearModels(String detectorId, Map models, ActionListener listener) { + Iterator id = models.keySet().iterator(); + clearModelForIterator(detectorId, models, id, listener); + } + + protected void clearModelForIterator(String detectorId, Map models, Iterator idIter, ActionListener listener) { + if (idIter.hasNext()) { + String modelId = idIter.next(); + if (SingleStreamModelIdMapper.getConfigIdForModelId(modelId).equals(detectorId)) { + 
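+ // The model belongs to the config being cleared: evict it, delete its checkpoint, and recurse to the next id + // only after the delete round-trips, so a single clear operation never issues concurrent delete requests.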
models.remove(modelId); + checkpointDao + .deleteModelCheckpoint( + modelId, + ActionListener.wrap(r -> clearModelForIterator(detectorId, models, idIter, listener), listener::onFailure) + ); + } else { + clearModelForIterator(detectorId, models, idIter, listener); + } + } else { + listener.onResponse(null); + } + } + + @SuppressWarnings("unchecked") + public ResultType score( + Sample sample, + String modelId, + ModelState modelState, + Config config + ) { + + ResultType result = createEmptyResult(); + Optional model = modelState.getModel(); + try { + if (model != null && model.isPresent()) { + RCFModelType rcfModel = model.get(); + + Pair dataSamplePair = featureManager + .getContinuousSamples(config, modelState.getSamples(), modelState.getLastProcessedSample(), sample); + + double[][] data = dataSamplePair.getKey(); + RCFDescriptor lastResult = null; + for (int i = 0; i < data.length; i++) { + // we are sure that the process method will indeed return an instance of RCFDescriptor. + lastResult = (RCFDescriptor) rcfModel.process(data[i], 0); + } + modelState.clearSamples(); + + if (lastResult != null) { + result = toResult(rcfModel.getForest(), lastResult); + } + + modelState.setLastProcessedSample(dataSamplePair.getValue()); + } + } catch (Exception e) { + LOG + .error( + new ParameterizedMessage( + "Fail to score for [{}]: model Id [{}], feature [{}]", + modelState.getEntity().isEmpty() ? modelState.getConfigId() : modelState.getEntity().get(), + modelId, + Arrays.toString(sample.getValueList()) + ), + e + ); + throw e; + } finally { + modelState.setLastUsedTime(clock.instant()); + } + return result; + } + + protected abstract ResultType createEmptyResult(); + + protected abstract ResultType toResult( + RandomCutForest forecast, + RCFDescriptor castDescriptor + ); +} diff --git a/src/main/java/org/opensearch/ad/ml/ModelState.java b/src/main/java/org/opensearch/timeseries/ml/ModelState.java similarity index 56% rename from src/main/java/org/opensearch/ad/ml/ModelState.java rename to src/main/java/org/opensearch/timeseries/ml/ModelState.java index 9e909bc58..d628adbc9 100644 --- a/src/main/java/org/opensearch/ad/ml/ModelState.java +++ b/src/main/java/org/opensearch/timeseries/ml/ModelState.java @@ -9,92 +9,98 @@ * GitHub history for details. */ -package org.opensearch.ad.ml; +package org.opensearch.timeseries.ml; import java.time.Clock; import java.time.Duration; import java.time.Instant; +import java.util.ArrayDeque; +import java.util.Deque; import java.util.HashMap; import java.util.Map; +import java.util.Optional; -import org.opensearch.ad.ExpiringState; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Entity; -/** - * A ML model and states such as usage. 
- */ -public class ModelState implements ExpiringState { - +public class ModelState implements org.opensearch.timeseries.ExpiringState { public static String MODEL_TYPE_KEY = "model_type"; public static String LAST_USED_TIME_KEY = "last_used_time"; public static String LAST_CHECKPOINT_TIME_KEY = "last_checkpoint_time"; public static String PRIORITY_KEY = "priority"; - private T model; - private String modelId; - private String detectorId; - private String modelType; + + protected T model; + protected String modelId; + protected String configId; + protected String modelType; // time when the ML model was used last time - private Instant lastUsedTime; - private Instant lastCheckpointTime; - private Clock clock; - private float priority; + protected Instant lastUsedTime; + protected Instant lastCheckpointTime; + protected Clock clock; + protected float priority; + protected Sample lastProcessedSample; + protected Deque samples; + protected Optional entity; /** * Constructor. * * @param model ML model * @param modelId Id of model partition - * @param detectorId Id of detector this model partition is used for + * @param configId Id of analysis this model partition is used for * @param modelType type of model * @param clock UTC clock * @param priority Priority of the model state. Used in multi-entity detectors' cache. + * @param lastProcessedSample last processed sample. Used in interpolation. + * @param entity Entity info if this is a HC entity state + * @param samples existing samples that haven't been processed */ - public ModelState(T model, String modelId, String detectorId, String modelType, Clock clock, float priority) { + public ModelState( + T model, + String modelId, + String configId, + String modelType, + Clock clock, + float priority, + Sample lastProcessedSample, + Optional entity, + Deque samples + ) { this.model = model; this.modelId = modelId; - this.detectorId = detectorId; + this.configId = configId; this.modelType = modelType; this.lastUsedTime = clock.instant(); // this is inaccurate until we find the last checkpoint time from disk this.lastCheckpointTime = Instant.MIN; this.clock = clock; this.priority = priority; + this.lastProcessedSample = lastProcessedSample; + this.entity = entity; + this.samples = samples; } /** - * Create state with zero priority. Used in single-entity detector. + * Constructor. Used in single-stream analysis. * - * @param Model object's type - * @param model The actual model object - * @param modelId Model Id - * @param detectorId Detector Id - * @param modelType Model type like RCF model + * @param model ML model + * @param modelId Id of model partition + * @param configId Id of analysis this model partition is used for + * @param modelType type of model * @param clock UTC clock - * - * @return the created model state + * @param lastProcessedSample last processed sample. Used in interpolation. + * @param samples existing samples that haven't been processed */ - public static ModelState createSingleEntityModelState( + public ModelState( T model, String modelId, - String detectorId, + String configId, String modelType, - Clock clock + Clock clock, + Sample lastProcessedSample, + Deque samples ) { - return new ModelState<>(model, modelId, detectorId, modelType, clock, 0f); - } - - /** - * Returns the ML model. - * - * @return the ML model. 
- */ - public T getModel() { - return this.model; - } - - public void setModel(T model) { - this.model = model; + this(model, modelId, configId, modelType, clock, 0, lastProcessedSample, Optional.empty(), new ArrayDeque<>()); } /** @@ -106,15 +112,6 @@ public String getModelId() { return modelId; } - /** - * Gets the detectorID of the model - * - * @return detectorId associated with the model - */ - public String getId() { - return detectorId; - } - /** * Gets the type of the model * @@ -172,16 +169,90 @@ public void setPriority(float priority) { this.priority = priority; } + public Sample getLastProcessedSample() { + return lastProcessedSample; + } + + public void setLastProcessedSample(Sample lastProcessedSample) { + this.lastProcessedSample = lastProcessedSample; + } + + @Override + public boolean expired(Duration stateTtl) { + return expired(lastUsedTime, stateTtl, clock.instant()); + } + + /** + * Gets the Config ID of the model + * + * @return the config id associated with the model + */ + public String getConfigId() { + return configId; + } + + /** + * In old checkpoint mapping, we don't have entity. It's fine we are missing + * entity as it is mostly used for debugging. + * @return entity + */ + public Optional getEntity() { + return entity; + } + + public Deque getSamples() { + return this.samples; + } + + public void addSample(Sample sample) { + if (this.samples == null) { + this.samples = new ArrayDeque<>(); + } + if (sample != null && sample.getValueList() != null && sample.getValueList().length != 0) { + this.samples.add(sample); + } + } + + /** + * Sets a model. + * + * @param model model instance + */ + public void setModel(T model) { + this.model = model; + } + + /** + * + * @return optional model. + */ + public Optional getModel() { + return Optional.ofNullable(this.model); + } + + public void clearSamples() { + if (samples != null) { + samples.clear(); + } + } + + public void clear() { + clearSamples(); + model = null; + lastProcessedSample = null; + } + /** * Gets the Model State as a map * * @return Map of ModelStates */ + @SuppressWarnings("serial") public Map getModelStateAsMap() { return new HashMap() { { put(CommonName.MODEL_ID_FIELD, modelId); - put(ADCommonName.DETECTOR_ID_KEY, detectorId); + put(CommonName.CONFIG_ID_KEY, configId); put(MODEL_TYPE_KEY, modelType); /* A stats API broadcasts requests to all nodes and renders node responses using toXContent. * @@ -195,18 +266,10 @@ public Map getModelStateAsMap() { if (lastCheckpointTime != Instant.MIN) { put(LAST_CHECKPOINT_TIME_KEY, lastCheckpointTime.toEpochMilli()); } - if (model != null && model instanceof EntityModel) { - EntityModel summary = (EntityModel) model; - if (summary.getEntity().isPresent()) { - put(CommonName.ENTITY_KEY, summary.getEntity().get().toStat()); - } + if (entity.isPresent()) { + put(CommonName.ENTITY_KEY, entity.get().toStat()); } } }; } - - @Override - public boolean expired(Duration stateTtl) { - return expired(lastUsedTime, stateTtl, clock.instant()); - } } diff --git a/src/main/java/org/opensearch/timeseries/ml/Sample.java b/src/main/java/org/opensearch/timeseries/ml/Sample.java new file mode 100644 index 000000000..2379cd13c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/Sample.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. 
See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.util.ParseUtils; + +public class Sample implements ToXContentObject { + private final double[] data; + private final Instant dataStartTime; + private final Instant dataEndTime; + + public Sample(double[] data, Instant dataStartTime, Instant dataEndTime) { + super(); + this.data = data; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + // Invalid sample + public Sample() { + this.data = new double[0]; + this.dataStartTime = this.dataEndTime = Instant.MIN; + } + + public double[] getValueList() { + return data; + } + + public Instant getDataStartTime() { + return dataStartTime; + } + + public Instant getDataEndTime() { + return dataEndTime; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + XContentBuilder xContentBuilder = builder.startObject(); + if (data != null) { + xContentBuilder.array(CommonName.VALUE_LIST_FIELD, data); + } + if (dataStartTime != null) { + xContentBuilder.field(CommonName.DATA_START_TIME_FIELD, dataStartTime.toEpochMilli()); + } + if (dataEndTime != null) { + xContentBuilder.field(CommonName.DATA_END_TIME_FIELD, dataEndTime.toEpochMilli()); + } + return xContentBuilder.endObject(); + } + + public static Sample parse(XContentParser parser) throws IOException { + Instant dataStartTime = null; + Instant dataEndTime = null; + List valueList = new ArrayList<>(); + + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + + switch (fieldName) { + case CommonName.DATA_START_TIME_FIELD: + dataStartTime = ParseUtils.toInstant(parser); + break; + case CommonName.DATA_END_TIME_FIELD: + dataEndTime = ParseUtils.toInstant(parser); + break; + case CommonName.VALUE_LIST_FIELD: + ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + valueList.add(parser.doubleValue()); + } + break; + default: + parser.skipChildren(); + break; + } + } + + return new Sample(valueList.stream().mapToDouble(Double::doubleValue).toArray(), dataStartTime, dataEndTime); + } + + public boolean isInvalid() { + return dataStartTime.compareTo(Instant.MIN) == 0 || dataEndTime.compareTo(Instant.MIN) == 0; + } + + @Override + public String toString() { + return "Sample [data=" + Arrays.toString(data) + ", dataStartTime=" + dataStartTime + ", dataEndTime=" + dataEndTime + "]"; + } +} diff --git a/src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java similarity index 74% rename from src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java rename to src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java index ac3ce899d..cf045f79d 100644 --- 
a/src/main/java/org/opensearch/ad/ml/SingleStreamModelIdMapper.java +++ b/src/main/java/org/opensearch/timeseries/ml/SingleStreamModelIdMapper.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ml; +package org.opensearch.timeseries.ml; import java.util.Locale; import java.util.regex.Matcher; @@ -22,9 +22,10 @@ * */ public class SingleStreamModelIdMapper { - protected static final String DETECTOR_ID_PATTERN = "(.*)_model_.+"; + protected static final String CONFIG_ID_PATTERN = "(.*)_model_.+"; protected static final String RCF_MODEL_ID_PATTERN = "%s_model_rcf_%d"; protected static final String THRESHOLD_MODEL_ID_PATTERN = "%s_model_threshold"; + protected static final String CASTER_MODEL_ID_PATTERN = "%s_model_caster"; /** * Returns the model ID for the RCF model partition. @@ -48,14 +49,24 @@ public static String getThresholdModelId(String detectorId) { } /** - * Gets the detector id from the model id. + * Returns the model ID for the rcf caster model. + * + * @param forecasterId ID of the forecaster for which the model is trained + * @return ID for the forecaster model + */ + public static String getCasterModelId(String forecasterId) { + return String.format(Locale.ROOT, CASTER_MODEL_ID_PATTERN, forecasterId); + } + + /** + * Gets the config id from the model id. * * @param modelId id of a model * @return id of the detector the model is for * @throws IllegalArgumentException if model id is invalid */ - public static String getDetectorIdForModelId(String modelId) { - Matcher matcher = Pattern.compile(DETECTOR_ID_PATTERN).matcher(modelId); + public static String getConfigIdForModelId(String modelId) { + Matcher matcher = Pattern.compile(CONFIG_ID_PATTERN).matcher(modelId); if (matcher.matches()) { return matcher.group(1); } else { @@ -70,7 +81,7 @@ public static String getDetectorIdForModelId(String modelId) { * @return thresholding model Id */ public static String getThresholdModelIdFromRCFModelId(String rcfModelId) { - String detectorId = getDetectorIdForModelId(rcfModelId); + String detectorId = getConfigIdForModelId(rcfModelId); return getThresholdModelId(detectorId); } } diff --git a/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java new file mode 100644 index 000000000..960234701 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ml/TimeSeriesSingleStreamCheckpointDao.java @@ -0,0 +1,16 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ml; + +public class TimeSeriesSingleStreamCheckpointDao { + +} diff --git a/src/main/java/org/opensearch/ad/model/DetectorState.java b/src/main/java/org/opensearch/timeseries/model/ConfigState.java similarity index 83% rename from src/main/java/org/opensearch/ad/model/DetectorState.java rename to src/main/java/org/opensearch/timeseries/model/ConfigState.java index a4959417b..4af52f2ee 100644 --- a/src/main/java/org/opensearch/ad/model/DetectorState.java +++ b/src/main/java/org/opensearch/timeseries/model/ConfigState.java @@ -9,9 +9,9 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; -public enum DetectorState { +public enum ConfigState { DISABLED, INIT, RUNNING diff --git a/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java b/src/main/java/org/opensearch/timeseries/model/Job.java similarity index 84% rename from src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java rename to src/main/java/org/opensearch/timeseries/model/Job.java index 7ef5ae528..d258279e7 100644 --- a/src/main/java/org/opensearch/ad/model/AnomalyDetectorJob.java +++ b/src/main/java/org/opensearch/timeseries/model/Job.java @@ -9,9 +9,8 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.DEFAULT_AD_JOB_LOC_DURATION_SECONDS; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import java.io.IOException; @@ -19,6 +18,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.ParseField; +import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -31,8 +31,8 @@ import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; import org.opensearch.jobscheduler.spi.schedule.Schedule; import org.opensearch.jobscheduler.spi.schedule.ScheduleParser; -import org.opensearch.timeseries.model.IntervalTimeConfiguration; -import org.opensearch.timeseries.model.TimeConfiguration; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; import com.google.common.base.Objects; @@ -40,15 +40,15 @@ /** * Anomaly detector job. 
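+ * The job document is shared by AD and forecasting; the analysis type field added below records which + * analysis the scheduler should run.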
*/ -public class AnomalyDetectorJob implements Writeable, ToXContentObject, ScheduledJobParameter { +public class Job implements Writeable, ToXContentObject, ScheduledJobParameter { enum ScheduleType { CRON, INTERVAL } - public static final String PARSE_FIELD_NAME = "AnomalyDetectorJob"; + public static final String PARSE_FIELD_NAME = "TimeSeriesJob"; public static final NamedXContentRegistry.Entry XCONTENT_REGISTRY = new NamedXContentRegistry.Entry( - AnomalyDetectorJob.class, + Job.class, new ParseField(PARSE_FIELD_NAME), it -> parse(it) ); @@ -64,7 +64,9 @@ enum ScheduleType { public static final String DISABLED_TIME_FIELD = "disabled_time"; public static final String USER_FIELD = "user"; private static final String RESULT_INDEX_FIELD = "result_index"; + private static final String TYPE_FIELD = "type"; + // name is config id private final String name; private final Schedule schedule; private final TimeConfiguration windowDelay; @@ -75,8 +77,9 @@ enum ScheduleType { private final Long lockDurationSeconds; private final User user; private String resultIndex; + private AnalysisType analysisType; - public AnomalyDetectorJob( + public Job( String name, Schedule schedule, TimeConfiguration windowDelay, @@ -86,7 +89,8 @@ public AnomalyDetectorJob( Instant lastUpdateTime, Long lockDurationSeconds, User user, - String resultIndex + String resultIndex, + AnalysisType type ) { this.name = name; this.schedule = schedule; @@ -98,11 +102,12 @@ public AnomalyDetectorJob( this.lockDurationSeconds = lockDurationSeconds; this.user = user; this.resultIndex = resultIndex; + this.analysisType = type; } - public AnomalyDetectorJob(StreamInput input) throws IOException { + public Job(StreamInput input) throws IOException { name = input.readString(); - if (input.readEnum(AnomalyDetectorJob.ScheduleType.class) == ScheduleType.CRON) { + if (input.readEnum(Job.ScheduleType.class) == ScheduleType.CRON) { schedule = new CronSchedule(input); } else { schedule = new IntervalSchedule(input); @@ -119,6 +124,8 @@ public AnomalyDetectorJob(StreamInput input) throws IOException { user = null; } resultIndex = input.readOptionalString(); + String typeStr = input.readOptionalString(); + this.analysisType = input.readEnum(AnalysisType.class); } @Override @@ -131,7 +138,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws .field(IS_ENABLED_FIELD, isEnabled) .field(ENABLED_TIME_FIELD, enabledTime.toEpochMilli()) .field(LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()) - .field(LOCK_DURATION_SECONDS, lockDurationSeconds); + .field(LOCK_DURATION_SECONDS, lockDurationSeconds) + .field(TYPE_FIELD, analysisType); if (disabledTime != null) { xContentBuilder.field(DISABLED_TIME_FIELD, disabledTime.toEpochMilli()); } @@ -166,9 +174,10 @@ public void writeTo(StreamOutput output) throws IOException { output.writeBoolean(false); // user does not exist } output.writeOptionalString(resultIndex); + output.writeEnum(analysisType); } - public static AnomalyDetectorJob parse(XContentParser parser) throws IOException { + public static Job parse(XContentParser parser) throws IOException { String name = null; Schedule schedule = null; TimeConfiguration windowDelay = null; @@ -177,9 +186,10 @@ public static AnomalyDetectorJob parse(XContentParser parser) throws IOException Instant enabledTime = null; Instant disabledTime = null; Instant lastUpdateTime = null; - Long lockDurationSeconds = DEFAULT_AD_JOB_LOC_DURATION_SECONDS; + Long lockDurationSeconds = TimeSeriesSettings.DEFAULT_JOB_LOC_DURATION_SECONDS; User 
user = null; String resultIndex = null; + String analysisType = null; ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -217,12 +227,15 @@ public static AnomalyDetectorJob parse(XContentParser parser) throws IOException case RESULT_INDEX_FIELD: resultIndex = parser.text(); break; + case TYPE_FIELD: + analysisType = parser.text(); + break; default: parser.skipChildren(); break; } } - return new AnomalyDetectorJob( + return new Job( name, schedule, windowDelay, @@ -232,7 +245,10 @@ public static AnomalyDetectorJob parse(XContentParser parser) throws IOException lastUpdateTime, lockDurationSeconds, user, - resultIndex + resultIndex, + (Strings.isEmpty(analysisType) || AnalysisType.AD == AnalysisType.valueOf(analysisType)) + ? AnalysisType.AD + : AnalysisType.FORECAST ); } @@ -242,7 +258,7 @@ public boolean equals(Object o) { return true; if (o == null || getClass() != o.getClass()) return false; - AnomalyDetectorJob that = (AnomalyDetectorJob) o; + Job that = (Job) o; return Objects.equal(getName(), that.getName()) && Objects.equal(getSchedule(), that.getSchedule()) && Objects.equal(isEnabled(), that.isEnabled()) @@ -250,12 +266,13 @@ public boolean equals(Object o) { && Objects.equal(getDisabledTime(), that.getDisabledTime()) && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) && Objects.equal(getLockDurationSeconds(), that.getLockDurationSeconds()) - && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()); + && Objects.equal(getCustomResultIndex(), that.getCustomResultIndex()) + && Objects.equal(getAnalysisType(), that.getAnalysisType()); } @Override public int hashCode() { - return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime); + return Objects.hashCode(name, schedule, isEnabled, enabledTime, lastUpdateTime, analysisType); } @Override @@ -303,4 +320,8 @@ public User getUser() { public String getCustomResultIndex() { return resultIndex; } + + public AnalysisType getAnalysisType() { + return analysisType; + } } diff --git a/src/main/java/org/opensearch/ad/model/MergeableList.java b/src/main/java/org/opensearch/timeseries/model/MergeableList.java similarity index 91% rename from src/main/java/org/opensearch/ad/model/MergeableList.java rename to src/main/java/org/opensearch/timeseries/model/MergeableList.java index 4bb0d7842..fd9f26e84 100644 --- a/src/main/java/org/opensearch/ad/model/MergeableList.java +++ b/src/main/java/org/opensearch/timeseries/model/MergeableList.java @@ -9,10 +9,12 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.util.List; +import org.opensearch.ad.model.Mergeable; + public class MergeableList implements Mergeable { private final List elements; diff --git a/src/main/java/org/opensearch/ad/model/ModelProfile.java b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java similarity index 97% rename from src/main/java/org/opensearch/ad/model/ModelProfile.java rename to src/main/java/org/opensearch/timeseries/model/ModelProfile.java index 1d6d0ce85..63fdbcd02 100644 --- a/src/main/java/org/opensearch/ad/model/ModelProfile.java +++ b/src/main/java/org/opensearch/timeseries/model/ModelProfile.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.io.IOException; @@ -22,7 +22,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.model.Entity; /** * Used to show model information in profile API diff --git a/src/main/java/org/opensearch/ad/model/ADTaskState.java b/src/main/java/org/opensearch/timeseries/model/TaskState.java similarity index 68% rename from src/main/java/org/opensearch/ad/model/ADTaskState.java rename to src/main/java/org/opensearch/timeseries/model/TaskState.java index 68462f816..2b5c4240e 100644 --- a/src/main/java/org/opensearch/ad/model/ADTaskState.java +++ b/src/main/java/org/opensearch/timeseries/model/TaskState.java @@ -9,47 +9,47 @@ * GitHub history for details. */ -package org.opensearch.ad.model; +package org.opensearch.timeseries.model; import java.util.List; import com.google.common.collect.ImmutableList; /** - * AD task states. + * AD and forecasting task states. *
    *
  • CREATED: - * When user start a historical detector, we will create one task to track the detector + * AD: When a user starts a historical detector, we will create one task to track the detector execution and set its state as CREATED * *
  • INIT: - * After task created, coordinate node will gather all eligible node’s state and dispatch + * AD: After the task is created, the coordinating node will gather all eligible nodes’ states and dispatch task to the worker node with lowest load. When the worker node receives the request, * it will set the task state as INIT immediately, then start to run cold start to train * RCF model. We will track the initialization progress in task. * Init_Progress=ModelUpdates/MinSampleSize * *
  • RUNNING: - * If RCF model gets enough data points and passed training, it will start to detect data + * AD: If the RCF model gets enough data points and passes training, it will start to detect data normally and output positive anomaly scores. Once the RCF model starts to output positive * anomaly score, we will set the task state as RUNNING and init progress as 100%. We will * track task running progress in task: Task_Progress=DetectedPieces/AllPieces * *
  • FINISHED: - * When all historical data detected, we set the task state as FINISHED and task progress + * AD: When all historical data has been detected, we set the task state as FINISHED and task progress as 100%. * *
  • STOPPED: - * User can cancel a running task by stopping detector, for example, user want to tune + * AD: A user can cancel a running task by stopping the detector, for example, because the user wants to tune feature and reran and don’t want current task run any more. When a historical detector * stopped, we will mark the task flag cancelled as true, when run next piece, we will * check this flag and stop the task. Then task stopped, will set its state as STOPPED * *
  • FAILED: - * If any exception happen, we will set task state as FAILED + * AD: If any exception happens, we will set the task state as FAILED *
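+ * Forecasting tasks reuse the same lifecycle states; the AD prefix marks notes written for the original + * anomaly detection flow.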
*/ -public enum ADTaskState { +public enum TaskState { CREATED, INIT, RUNNING, @@ -58,5 +58,5 @@ public enum ADTaskState { FINISHED; public static List NOT_ENDED_STATES = ImmutableList - .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name()); + .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name()); } diff --git a/src/main/java/org/opensearch/timeseries/model/TaskType.java b/src/main/java/org/opensearch/timeseries/model/TaskType.java new file mode 100644 index 000000000..74481871d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/TaskType.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.model; + +import java.util.List; +import java.util.stream.Collectors; + +public interface TaskType { + String name(); + + public static List taskTypeToString(List adTaskTypes) { + return adTaskTypes.stream().map(type -> type.name()).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java new file mode 100644 index 000000000..3d729703e --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/model/TimeSeriesTask.java @@ -0,0 +1,442 @@ +package org.opensearch.timeseries.model; + +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; + +import java.io.IOException; +import java.time.Instant; + +import org.opensearch.ad.model.ADTask; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContent.Params; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.annotation.Generated; + +import com.google.common.base.Objects; + +public abstract class TimeSeriesTask implements ToXContentObject, Writeable { + + public static final String TASK_ID_FIELD = "task_id"; + public static final String LAST_UPDATE_TIME_FIELD = "last_update_time"; + public static final String STARTED_BY_FIELD = "started_by"; + public static final String STOPPED_BY_FIELD = "stopped_by"; + public static final String ERROR_FIELD = "error"; + public static final String STATE_FIELD = "state"; + public static final String TASK_PROGRESS_FIELD = "task_progress"; + public static final String INIT_PROGRESS_FIELD = "init_progress"; + public static final String CURRENT_PIECE_FIELD = "current_piece"; + public static final String EXECUTION_START_TIME_FIELD = "execution_start_time"; + public static final String EXECUTION_END_TIME_FIELD = "execution_end_time"; + public static final String IS_LATEST_FIELD = "is_latest"; + public static final String TASK_TYPE_FIELD = "task_type"; + public static final String CHECKPOINT_ID_FIELD = "checkpoint_id"; + public static final String COORDINATING_NODE_FIELD = "coordinating_node"; + public static final String WORKER_NODE_FIELD = "worker_node"; + public static final String ENTITY_FIELD = "entity"; + public static final String PARENT_TASK_ID_FIELD = "parent_task_id"; + public static final String ESTIMATED_MINUTES_LEFT_FIELD = "estimated_minutes_left"; + public static final String USER_FIELD = "user"; + public static final String HISTORICAL_TASK_PREFIX = "HISTORICAL"; + + protected String configId = null; + protected String taskId = null; + protected Instant lastUpdateTime = null; + protected String startedBy = null; + protected String 
stoppedBy = null; + protected String error = null; + protected String state = null; + protected Float taskProgress = null; + protected Float initProgress = null; + protected Instant currentPiece = null; + protected Instant executionStartTime = null; + protected Instant executionEndTime = null; + protected Boolean isLatest = null; + protected String taskType = null; + protected String checkpointId = null; + protected String coordinatingNode = null; + protected String workerNode = null; + protected Entity entity = null; + protected String parentTaskId = null; + protected Integer estimatedMinutesLeft = null; + protected User user = null; + + @SuppressWarnings("unchecked") + public abstract static class Builder> { + protected String configId = null; + protected String taskId = null; + protected String taskType = null; + protected String state = null; + protected Float taskProgress = null; + protected Float initProgress = null; + protected Instant currentPiece = null; + protected Instant executionStartTime = null; + protected Instant executionEndTime = null; + protected Boolean isLatest = null; + protected String error = null; + protected String checkpointId = null; + protected Instant lastUpdateTime = null; + protected String startedBy = null; + protected String stoppedBy = null; + protected String coordinatingNode = null; + protected String workerNode = null; + protected Entity entity = null; + protected String parentTaskId; + protected Integer estimatedMinutesLeft; + protected User user = null; + + public Builder() {} + + public T configId(String configId) { + this.configId = configId; + return (T) this; + } + + public T taskId(String taskId) { + this.taskId = taskId; + return (T) this; + } + + public T lastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + return (T) this; + } + + public T startedBy(String startedBy) { + this.startedBy = startedBy; + return (T) this; + } + + public T stoppedBy(String stoppedBy) { + this.stoppedBy = stoppedBy; + return (T) this; + } + + public T error(String error) { + this.error = error; + return (T) this; + } + + public T state(String state) { + this.state = state; + return (T) this; + } + + public T taskProgress(Float taskProgress) { + this.taskProgress = taskProgress; + return (T) this; + } + + public T initProgress(Float initProgress) { + this.initProgress = initProgress; + return (T) this; + } + + public T currentPiece(Instant currentPiece) { + this.currentPiece = currentPiece; + return (T) this; + } + + public T executionStartTime(Instant executionStartTime) { + this.executionStartTime = executionStartTime; + return (T) this; + } + + public T executionEndTime(Instant executionEndTime) { + this.executionEndTime = executionEndTime; + return (T) this; + } + + public T isLatest(Boolean isLatest) { + this.isLatest = isLatest; + return (T) this; + } + + public T taskType(String taskType) { + this.taskType = taskType; + return (T) this; + } + + public T checkpointId(String checkpointId) { + this.checkpointId = checkpointId; + return (T) this; + } + + public T coordinatingNode(String coordinatingNode) { + this.coordinatingNode = coordinatingNode; + return (T) this; + } + + public T workerNode(String workerNode) { + this.workerNode = workerNode; + return (T) this; + } + + public T entity(Entity entity) { + this.entity = entity; + return (T) this; + } + + public T parentTaskId(String parentTaskId) { + this.parentTaskId = parentTaskId; + return (T) this; + } + + public T estimatedMinutesLeft(Integer estimatedMinutesLeft) { + 
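+ // estimate, in minutes, of how much longer the task is expected to run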
this.estimatedMinutesLeft = estimatedMinutesLeft; + return (T) this; + } + + public T user(User user) { + this.user = user; + return (T) this; + } + } + + public boolean isHistoricalTask() { + return taskType.startsWith(TimeSeriesTask.HISTORICAL_TASK_PREFIX); + } + + /** + * Get detector level task id. If a task has no parent task, the task is detector level task. + * @return detector level task id + */ + public String getDetectorLevelTaskId() { + return getParentTaskId() != null ? getParentTaskId() : getTaskId(); + } + + public String getTaskId() { + return taskId; + } + + public void setTaskId(String taskId) { + this.taskId = taskId; + } + + public Instant getLastUpdateTime() { + return lastUpdateTime; + } + + public String getStartedBy() { + return startedBy; + } + + public String getStoppedBy() { + return stoppedBy; + } + + public String getError() { + return error; + } + + public void setError(String error) { + this.error = error; + } + + public String getState() { + return state; + } + + public void setState(String state) { + this.state = state; + } + + public Float getTaskProgress() { + return taskProgress; + } + + public Float getInitProgress() { + return initProgress; + } + + public Instant getCurrentPiece() { + return currentPiece; + } + + public Instant getExecutionStartTime() { + return executionStartTime; + } + + public Instant getExecutionEndTime() { + return executionEndTime; + } + + public Boolean getLatest() { + return isLatest; + } + + public String getTaskType() { + return taskType; + } + + public String getCheckpointId() { + return checkpointId; + } + + public String getCoordinatingNode() { + return coordinatingNode; + } + + public String getWorkerNode() { + return workerNode; + } + + public Entity getEntity() { + return entity; + } + + public String getParentTaskId() { + return parentTaskId; + } + + public Integer getEstimatedMinutesLeft() { + return estimatedMinutesLeft; + } + + public User getUser() { + return user; + } + + public String getConfigId() { + return configId; + } + + public void setLatest(Boolean latest) { + isLatest = latest; + } + + public void setLastUpdateTime(Instant lastUpdateTime) { + this.lastUpdateTime = lastUpdateTime; + } + + public boolean isDone() { + return !NOT_ENDED_STATES.contains(this.getState()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (taskId != null) { + builder.field(TimeSeriesTask.TASK_ID_FIELD, taskId); + } + if (lastUpdateTime != null) { + builder.field(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, lastUpdateTime.toEpochMilli()); + } + if (startedBy != null) { + builder.field(TimeSeriesTask.STARTED_BY_FIELD, startedBy); + } + if (stoppedBy != null) { + builder.field(TimeSeriesTask.STOPPED_BY_FIELD, stoppedBy); + } + if (error != null) { + builder.field(TimeSeriesTask.ERROR_FIELD, error); + } + if (state != null) { + builder.field(TimeSeriesTask.STATE_FIELD, state); + } + if (taskProgress != null) { + builder.field(TimeSeriesTask.TASK_PROGRESS_FIELD, taskProgress); + } + if (initProgress != null) { + builder.field(TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress); + } + if (currentPiece != null) { + builder.field(TimeSeriesTask.CURRENT_PIECE_FIELD, currentPiece.toEpochMilli()); + } + if (executionStartTime != null) { + builder.field(TimeSeriesTask.EXECUTION_START_TIME_FIELD, executionStartTime.toEpochMilli()); + } + if (executionEndTime != null) { + builder.field(TimeSeriesTask.EXECUTION_END_TIME_FIELD, executionEndTime.toEpochMilli()); + } + if (isLatest != 
null) { + builder.field(TimeSeriesTask.IS_LATEST_FIELD, isLatest); + } + if (taskType != null) { + builder.field(TimeSeriesTask.TASK_TYPE_FIELD, taskType); + } + if (checkpointId != null) { + builder.field(TimeSeriesTask.CHECKPOINT_ID_FIELD, checkpointId); + } + if (coordinatingNode != null) { + builder.field(TimeSeriesTask.COORDINATING_NODE_FIELD, coordinatingNode); + } + if (workerNode != null) { + builder.field(TimeSeriesTask.WORKER_NODE_FIELD, workerNode); + } + if (entity != null) { + builder.field(TimeSeriesTask.ENTITY_FIELD, entity); + } + if (parentTaskId != null) { + builder.field(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId); + } + if (estimatedMinutesLeft != null) { + builder.field(TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD, estimatedMinutesLeft); + } + if (user != null) { + builder.field(TimeSeriesTask.USER_FIELD, user); + } + return builder; + } + + @Generated + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + // compare against the common base type so forecasting task subclasses work as well + TimeSeriesTask that = (TimeSeriesTask) o; + return Objects.equal(getTaskId(), that.getTaskId()) + && Objects.equal(getLastUpdateTime(), that.getLastUpdateTime()) + && Objects.equal(getStartedBy(), that.getStartedBy()) + && Objects.equal(getStoppedBy(), that.getStoppedBy()) + && Objects.equal(getError(), that.getError()) + && Objects.equal(getState(), that.getState()) + && Objects.equal(getTaskProgress(), that.getTaskProgress()) + && Objects.equal(getInitProgress(), that.getInitProgress()) + && Objects.equal(getCurrentPiece(), that.getCurrentPiece()) + && Objects.equal(getExecutionStartTime(), that.getExecutionStartTime()) + && Objects.equal(getExecutionEndTime(), that.getExecutionEndTime()) + && Objects.equal(getLatest(), that.getLatest()) + && Objects.equal(getTaskType(), that.getTaskType()) + && Objects.equal(getCheckpointId(), that.getCheckpointId()) + && Objects.equal(getCoordinatingNode(), that.getCoordinatingNode()) + && Objects.equal(getWorkerNode(), that.getWorkerNode()) + && Objects.equal(getEntity(), that.getEntity()) + && Objects.equal(getParentTaskId(), that.getParentTaskId()) + && Objects.equal(getEstimatedMinutesLeft(), that.getEstimatedMinutesLeft()) + && Objects.equal(getUser(), that.getUser()); + } + + @Generated + @Override + public int hashCode() { + return Objects + .hashCode( + taskId, + lastUpdateTime, + startedBy, + stoppedBy, + error, + state, + taskProgress, + initProgress, + currentPiece, + executionStartTime, + executionEndTime, + isLatest, + taskType, + checkpointId, + coordinatingNode, + workerNode, + entity, + parentTaskId, + estimatedMinutesLeft, + user + ); + } + + public abstract boolean isEntityTask(); + + public abstract String getEntityModelId(); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java similarity index 92% rename from src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java index 50b051f6d..29836b7f1 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/BatchWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/BatchWorker.java @@ -9,7 +9,7 @@ * GitHub history for details.
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -20,13 +20,14 @@ import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ThreadedActionListener; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; /** * @@ -46,7 +47,7 @@ public BatchWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -58,7 +59,8 @@ public BatchWorker( Duration executionTtl, Setting batchSizeSetting, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager timeSeriesNodeStateManager, + AnalysisType context ) { super( queueName, @@ -78,7 +80,8 @@ public BatchWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + timeSeriesNodeStateManager, + context ); this.batchSize = batchSizeSetting.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(batchSizeSetting, it -> batchSize = it); diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java similarity index 71% rename from src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java index 91382a4b5..788d2e370 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapter.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckPointMaintainRequestAdapter.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -21,35 +21,42 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.util.DateUtils; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.common.Strings; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.util.DateUtils; -public class CheckPointMaintainRequestAdapter { +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Convert from ModelRequest to CheckpointWriteRequest + * + */ +public class CheckPointMaintainRequestAdapter & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CacheType extends TimeSeriesCache> { private static final Logger LOG = LogManager.getLogger(CheckPointMaintainRequestAdapter.class); - private CacheProvider cache; - private CheckpointDao checkpointDao; + private CheckpointDaoType checkpointDao; private String indexName; private Duration checkpointInterval; private Clock clock; + private Provider cache; public CheckPointMaintainRequestAdapter( - CacheProvider cache, - CheckpointDao checkpointDao, + CheckpointDaoType checkpointDao, String indexName, Setting checkpointIntervalSetting, Clock clock, ClusterService clusterService, - Settings settings + Settings settings, + Provider cache ) { - this.cache = cache; this.checkpointDao = checkpointDao; this.indexName = indexName; @@ -59,15 +66,16 @@ public CheckPointMaintainRequestAdapter( .addSettingsUpdateConsumer(checkpointIntervalSetting, it -> this.checkpointInterval = DateUtils.toDuration(it)); this.clock = clock; + this.cache = cache; } public Optional convert(CheckpointMaintainRequest request) { - String detectorId = request.getId(); - String modelId = request.getEntityModelId(); + String configId = request.getConfigId(); + String modelId = request.getModelId(); - Optional> stateToMaintain = cache.get().getForMaintainance(detectorId, modelId); - if (!stateToMaintain.isEmpty()) { - ModelState state = stateToMaintain.get(); + Optional> stateToMaintain = cache.get().getForMaintainance(configId, modelId); + if (stateToMaintain.isPresent()) { + ModelState state = stateToMaintain.get(); Instant instant = state.getLastCheckpointTime(); if (!checkpointDao.shouldSave(instant, false, checkpointInterval, clock)) { return Optional.empty(); @@ -85,7 +93,7 @@ public Optional convert(CheckpointMaintainRequest reques .of( new CheckpointWriteRequest( request.getExpirationEpochMs(), - detectorId, + configId, request.getPriority(), // If the document does not already exist, the contents of the upsert element // are inserted as a new document. 
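
The adapter's early return hinges on CheckpointDao.shouldSave, which throttles how often a model's checkpoint is rewritten. A sketch of what such a staleness test plausibly looks like, under the assumption that it persists only when forced or when the last checkpoint is older than the configured interval (the real implementation may differ):

    import java.time.Clock;
    import java.time.Duration;
    import java.time.Instant;

    public final class CheckpointThrottleSketch {
        // Persist if forced, if we have never checkpointed, or if the last
        // checkpoint is older than the configured interval.
        public static boolean shouldSave(Instant lastCheckpoint, boolean forceWrite, Duration interval, Clock clock) {
            if (forceWrite) {
                return true;
            }
            return lastCheckpoint == null || lastCheckpoint.plus(interval).isBefore(clock.instant());
        }
    }
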
diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java similarity index 58% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java index 28fdfcc91..479965240 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointMaintainRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainRequest.java @@ -9,17 +9,17 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public class CheckpointMaintainRequest extends QueuedRequest { - private String entityModelId; + private String modelId; - public CheckpointMaintainRequest(long expirationEpochMs, String detectorId, RequestPriority priority, String entityModelId) { - super(expirationEpochMs, detectorId, priority); - this.entityModelId = entityModelId; + public CheckpointMaintainRequest(long expirationEpochMs, String configId, RequestPriority priority, String entityModelId) { + super(expirationEpochMs, configId, priority); + this.modelId = entityModelId; } - public String getEntityModelId() { - return entityModelId; + public String getModelId() { + return modelId; } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java new file mode 100644 index 000000000..d6a1f48e1 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointMaintainWorker.java @@ -0,0 +1,95 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.function.Function; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CheckpointMaintainWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao> + extends ScheduledWorker { + + private Function> converter; + + public CheckpointMaintainWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + RateLimitedRequestWorker targetQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Function> converter, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + targetQueue, + stateTtl, + nodeStateManager, + context + ); + this.converter = converter; + } + + @Override + protected List transformRequests(List requests) { + List allRequests = new ArrayList<>(); + for (CheckpointMaintainRequest request : requests) { + Optional converted = converter.apply(request); + if (!converted.isEmpty()) { + allRequests.add(converted.get()); + } + } + return allRequests; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java similarity index 57% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java index e06d3e08e..8b81eb20a 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointReadWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointReadWorker.java @@ -9,10 +9,7 @@ * GitHub history for details. 
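
CheckpointMaintainWorker.transformRequests above delegates per-request conversion to an injected Function returning Optional, which keeps the scheduling logic generic while AD and forecasting each supply their own adapter. The same shape in a self-contained sketch (pure JDK; the names here are invented):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Optional;
    import java.util.function.Function;

    public final class ConverterSketch {
        // Apply the injected converter and keep only the requests that survive,
        // as transformRequests does for CheckpointMaintainRequest -> CheckpointWriteRequest.
        static <A, B> List<B> transform(List<A> requests, Function<A, Optional<B>> converter) {
            List<B> out = new ArrayList<>();
            for (A request : requests) {
                converter.apply(request).ifPresent(out::add);
            }
            return out;
        }
    }
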
*/ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -21,7 +18,6 @@ import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Random; import java.util.Set; @@ -33,63 +29,61 @@ import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetRequest; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.indices.ADIndex; -import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.ml.ThresholdingResult; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Provider; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; import org.opensearch.timeseries.stats.StatNames; -import org.opensearch.timeseries.util.ParseUtils; - -/** - * a queue for loading model checkpoint. The read is a multi-get query. Possible results are: - * a). If a checkpoint is not found, we forward that request to the cold start queue. - * b). When a request gets errors, the queue does not change its expiry time and puts - * that request to the end of the queue and automatically retries them before they expire. - * c) When a checkpoint is found, we load that point to memory and score the input - * data point and save the result if a complete model exists. Otherwise, we enqueue - * the sample. If we can host that model in memory (e.g., there is enough memory), - * we put the loaded model to cache. Otherwise (e.g., a cold entity), we write the - * updated checkpoint back to disk. 
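
The javadoc removed above still describes the read path accurately: checkpoints are fetched in one multi-get; a missing checkpoint routes the request to the cold start queue, errors are retried before expiry, and hits are scored and then either cached or written back. The batch construction, isolated into a sketch that mirrors the toBatchRequest override further down:

    import java.util.List;

    import org.opensearch.action.get.MultiGetRequest;

    public final class CheckpointBatchReadSketch {
        // One multi-get item per model id; requests without a model id are skipped.
        static MultiGetRequest toBatchRequest(String checkpointIndex, List<String> modelIds) {
            MultiGetRequest multiGet = new MultiGetRequest();
            for (String modelId : modelIds) {
                if (modelId != null) {
                    multiGet.add(new MultiGetRequest.Item(checkpointIndex, modelId));
                }
            }
            return multiGet;
        }
    }
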
- * - */ -public class CheckpointReadWorker extends BatchWorker { +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.ResultBulkRequest; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class CheckpointReadWorker, ResultWriteBatchRequestType extends ResultBulkRequest, RCFResultType extends IntermediateResult, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker, ResultHandlerType extends IndexMemoryPressureAwareResultHandler, ResultWriteWorkerType extends ResultWriteWorker> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class); - public static final String WORKER_NAME = "checkpoint-read"; - private final ModelManager modelManager; - private final CheckpointDao checkpointDao; - private final EntityColdStartWorker entityColdStartQueue; - private final ResultWriteWorker resultWriteQueue; - private final ADIndexManagement indexUtil; - private final CacheProvider cacheProvider; - private final CheckpointWriteWorker checkpointWriteQueue; - private final ADStats adStats; + + protected final ModelManagerType modelManager; + protected final CheckpointType checkpointDao; + protected final ColdStartWorkerType entityColdStartWorker; + protected final ResultWriteWorkerType resultWriteWorker; + protected final IndexManagementType indexUtil; + protected final Stats timeSeriesStats; + protected final CheckpointWriteWorkerType checkpointWriteWorker; + protected final Provider> cacheProvider; + protected final String checkpointIndexName; + protected final StatNames modelCorruptionStat; public CheckpointReadWorker( + String workerName, long heapSizeInBytes, int singleRequestSizeInBytes, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -98,19 +92,24 @@ public CheckpointReadWorker( float lowSegmentPruneRatio, int maintenanceFreqConstant, Duration executionTtl, - ModelManager modelManager, - CheckpointDao checkpointDao, - EntityColdStartWorker entityColdStartQueue, - ResultWriteWorker resultWriteQueue, + ModelManagerType modelManager, + CheckpointType checkpointDao, + ColdStartWorkerType entityColdStartWorker, + ResultWriteWorkerType resultWriteWorker, NodeStateManager stateManager, - ADIndexManagement indexUtil, - CacheProvider cacheProvider, + IndexManagementType indexUtil, + Provider> cacheProvider, Duration stateTtl, - CheckpointWriteWorker checkpointWriteQueue, - ADStats adStats + CheckpointWriteWorkerType checkpointWriteWorker, + Stats timeSeriesStats, + Setting concurrencySetting, + Setting batchSizeSetting, + String checkpointIndexName, + StatNames modelCorruptionStat, + AnalysisType context ) { super( - WORKER_NAME, + workerName, heapSizeInBytes, singleRequestSizeInBytes, maxHeapPercentForQueueSetting, @@ -124,21 +123,24 @@ public CheckpointReadWorker( 
mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + stateManager, + context ); this.modelManager = modelManager; this.checkpointDao = checkpointDao; - this.entityColdStartQueue = entityColdStartQueue; - this.resultWriteQueue = resultWriteQueue; + this.entityColdStartWorker = entityColdStartWorker; + this.resultWriteWorker = resultWriteWorker; this.indexUtil = indexUtil; this.cacheProvider = cacheProvider; - this.checkpointWriteQueue = checkpointWriteQueue; - this.adStats = adStats; + this.checkpointWriteWorker = checkpointWriteWorker; + this.timeSeriesStats = timeSeriesStats; + this.checkpointIndexName = checkpointIndexName; + this.modelCorruptionStat = modelCorruptionStat; } @Override @@ -154,20 +156,20 @@ protected void executeBatchRequest(MultiGetRequest request, ActionListener toProcess) { + protected MultiGetRequest toBatchRequest(List toProcess) { MultiGetRequest multiGetRequest = new MultiGetRequest(); - for (EntityRequest request : toProcess) { - Optional modelId = request.getModelId(); - if (false == modelId.isPresent()) { + for (FeatureRequest request : toProcess) { + String modelId = request.getModelId(); + if (null == modelId) { continue; } - multiGetRequest.add(new MultiGetRequest.Item(ADCommonName.CHECKPOINT_INDEX_NAME, modelId.get())); + multiGetRequest.add(new MultiGetRequest.Item(checkpointIndexName, modelId)); } return multiGetRequest; } @Override - protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { + protected ActionListener getResponseListener(List toProcess, MultiGetRequest batchRequest) { return ActionListener.wrap(response -> { final MultiGetItemResponse[] itemResponses = response.getResponses(); Map successfulRequests = new HashMap<>(); @@ -184,11 +186,11 @@ protected ActionListener getResponseListener(List getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && notFoundModels.contains(modelId.get())) { + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && notFoundModels.contains(modelId)) { // submit to cold start queue - entityColdStartQueue.put(origRequest); + entityColdStartWorker.put(origRequest); } } } @@ -239,15 +241,17 @@ protected ActionListener getResponseListener(List modelId = origRequest.getModelId(); - if (modelId.isPresent() && stopDetectorRequests.containsKey(modelId.get())) { - String adID = origRequest.detectorId; + for (FeatureRequest origRequest : toProcess) { + String modelId = origRequest.getModelId(); + if (modelId != null && stopDetectorRequests.containsKey(modelId)) { + String configID = origRequest.getConfigId(); nodeStateManager .setException( - adID, - new EndRunException(adID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId.get()), false) + configID, + new EndRunException(configID, CommonMessages.BUG_RESPONSE, stopDetectorRequests.get(modelId), false) ); + // once one EndRunException is set, we can break; no point setting the exception repeatedly + break; } } } @@ -260,7 +264,7 @@ protected ActionListener getResponseListener(List { if (ExceptionUtil.isOverloaded(exception)) { - LOG.error("too many get AD model checkpoint requests or shard not available"); + LOG.error("too many get model checkpoint requests or shard not available"); setCoolDownStart(); } else if (ExceptionUtil.isRetryAble(exception)) { // 
retry all of them @@ -271,9 +275,9 @@ protected ActionListener getResponseListener(List toProcess, + List toProcess, Map successfulRequests, Set retryableRequests ) { @@ -285,41 +289,41 @@ private void processCheckpointIteration( // if false, finally will process next checkpoints boolean processNextInCallBack = false; try { - EntityFeatureRequest origRequest = toProcess.get(i); + FeatureRequest origRequest = toProcess.get(i); - Optional modelIdOptional = origRequest.getModelId(); - if (false == modelIdOptional.isPresent()) { + String modelId = origRequest.getModelId(); + if (null == modelId) { return; } - String detectorId = origRequest.getId(); - Entity entity = origRequest.getEntity(); - - String modelId = modelIdOptional.get(); + String configId = origRequest.getConfigId(); + Optional entity = origRequest.getEntity(); MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId); if (checkpointResponse != null) { // successful requests - Optional> checkpoint = checkpointDao - .processGetResponse(checkpointResponse.getResponse(), modelId); + ModelState modelState = checkpointDao + .processHCGetResponse(checkpointResponse.getResponse(), modelId, configId); - if (false == checkpoint.isPresent()) { - // checkpoint is too big + if (null == modelState) { + // checkpoint is not available (e.g., too big or corrupted); cold start again + entityColdStartWorker.put(origRequest); return; } nodeStateManager - .getAnomalyDetector( - detectorId, - onGetDetector( + .getConfig( + configId, + context, + processIterationUsingConfig( origRequest, i, - detectorId, + configId, toProcess, successfulRequests, retryableRequests, - checkpoint, + modelState, entity, modelId ) @@ -336,39 +340,46 @@ private void processCheckpointIteration( } } - private ActionListener> onGetDetector( - EntityFeatureRequest origRequest, + protected ActionListener> processIterationUsingConfig( + FeatureRequest origRequest, int index, - String detectorId, - List toProcess, + String configId, + List toProcess, Map successfulRequests, Set retryableRequests, - Optional> checkpoint, - Entity entity, + ModelState restoredModelState, + Optional entity, String modelId ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - AnomalyDetector detector = detectorOptional.get(); + Config config = configOptional.get(); - ModelState modelState = modelManager - .processEntityCheckpoint(checkpoint, entity, modelId, detectorId, detector.getShingleSize()); - - ThresholdingResult result = null; + RCFResultType result = null; try { result = modelManager - .getAnomalyResultForEntity(origRequest.getCurrentFeature(), modelState, modelId, entity, detector.getShingleSize()); + .getResult( + new Sample( + origRequest.getCurrentFeature(), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), + Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds()) + ), + restoredModelState, + modelId, + entity, + config + ); } catch (IllegalArgumentException e) { // fail to score likely due to model corruption. Re-cold start to recover. 
LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", origRequest.getModelId()), e); - adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).increment(); - if (origRequest.getModelId().isPresent()) { - String entityModelId = origRequest.getModelId().get(); + timeSeriesStats.getStat(modelCorruptionStat.getName()).increment(); + if (null != origRequest.getModelId()) { + String entityModelId = origRequest.getModelId(); checkpointDao .deleteModelCheckpoint( entityModelId, @@ -380,57 +391,35 @@ private ActionListener> onGetDetector( ); } - entityColdStartQueue.put(origRequest); + entityColdStartWorker.put(origRequest); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); return; } - if (result != null && result.getRcfScore() > 0) { - RequestPriority requestPriority = result.getGrade() > 0 ? RequestPriority.HIGH : RequestPriority.MEDIUM; - - List resultsToSave = result - .toIndexableResults( - detector, - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()), - Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + detector.getIntervalInMilliseconds()), - Instant.now(), - Instant.now(), - ParseUtils.getFeatureData(origRequest.getCurrentFeature(), detector), - Optional.ofNullable(entity), - indexUtil.getSchemaVersion(ADIndex.RESULT), - modelId, - null, - null - ); - - for (AnomalyResult r : resultsToSave) { - resultWriteQueue - .put( - new ResultWriteRequest( - origRequest.getExpirationEpochMs(), - detectorId, - requestPriority, - r, - detector.getCustomResultIndex() - ) - ); - } - } + saveResult(result, config, origRequest, entity, modelId); // try to load to cache - boolean loaded = cacheProvider.get().hostIfPossible(detector, modelState); + boolean loaded = cacheProvider.get().hostIfPossible(config, restoredModelState); if (false == loaded) { // not in memory. Maybe cold entities or some other entities // have filled the slot while waiting for loading checkpoints. - checkpointWriteQueue.write(modelState, true, RequestPriority.LOW); + checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW); } processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }, exception -> { LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception)); - nodeStateManager.setException(detectorId, exception); + nodeStateManager.setException(configId, exception); processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests); }); } + + protected abstract void saveResult( + RCFResultType result, + Config config, + FeatureRequest origRequest, + Optional entity, + String modelId + ); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java similarity index 94% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java index 9c41e55be..02a374f82 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. 
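
The hostIfPossible fallback above is the core cache policy: after scoring, the worker tries to keep the restored model in memory, and only when the cache declines (for example, a cold entity) is the state written back to the checkpoint index at LOW priority. The decision in miniature, with hypothetical Cache and Writer interfaces standing in for TimeSeriesCache and CheckpointWriteWorker:

    import org.opensearch.timeseries.ratelimit.RequestPriority;

    public final class HostOrPersistSketch {
        interface Cache<M> { boolean hostIfPossible(Object config, M modelState); }
        interface Writer<M> { void write(M modelState, boolean force, RequestPriority priority); }

        // Keep hot models in memory; declined or evicted ones go back to disk.
        static <M> void afterScore(Cache<M> cache, Writer<M> writer, Object config, M state) {
            if (!cache.hostIfPossible(config, state)) {
                writer.write(state, true, RequestPriority.LOW);
            }
        }
    }
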
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import org.opensearch.action.update.UpdateRequest; diff --git a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java similarity index 74% rename from src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java index dd32e21c4..f1b23654b 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/CheckpointWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/CheckpointWriteWorker.java @@ -1,18 +1,9 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -31,34 +22,37 @@ import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.update.UpdateRequest; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.Strings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.util.ExceptionUtil; -public class CheckpointWriteWorker extends BatchWorker { +public abstract class CheckpointWriteWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class); - public static final String WORKER_NAME = "checkpoint-write"; - private final CheckpointDao checkpoint; - private final String indexName; - private final Duration checkpointInterval; + protected final CheckpointDaoType checkpoint; + protected final String indexName; + protected final Duration checkpointInterval; public CheckpointWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService 
adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -66,17 +60,20 @@ public CheckpointWriteWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting concurrencySetting, Duration executionTtl, - CheckpointDao checkpoint, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + CheckpointDaoType checkpoint, String indexName, Duration checkpointInterval, - NodeStateManager stateManager, - Duration stateTtl + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, @@ -88,11 +85,12 @@ public CheckpointWriteWorker( mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.checkpoint = checkpoint; this.indexName = indexName; @@ -131,7 +129,7 @@ protected ActionListener getResponseListener(List getResponseListener(List modelState, boolean forceWrite, RequestPriority priority) { + public void write(ModelState modelState, boolean forceWrite, RequestPriority priority) { Instant instant = modelState.getLastCheckpointTime(); if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { return; } if (modelState.getModel() != null) { - String detectorId = modelState.getId(); + String configId = modelState.getConfigId(); String modelId = modelState.getModelId(); - if (modelId == null || detectorId == null) { + if (modelId == null || configId == null) { return; } - nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(detectorId, modelId, modelState, priority)); + nodeStateManager.getConfig(configId, context, onGetConfig(configId, modelId, modelState, priority)); } } - private ActionListener> onGetDetector( - String detectorId, + private ActionListener> onGetConfig( + String configId, String modelId, - ModelState modelState, + ModelState modelState, RequestPriority priority ) { return ActionListener.wrap(detectorOptional -> { if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId)); return; } - AnomalyDetector detector = detectorOptional.get(); + Config detector = detectorOptional.get(); try { Map source = checkpoint.toIndexSource(modelState); @@ -191,7 +189,7 @@ private ActionListener> onGetDetector( modelState.setLastCheckpointTime(clock.instant()); CheckpointWriteRequest request = new CheckpointWriteRequest( System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. 
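
The upsert comment above captures the key property of the checkpoint write path: an update request with doc-as-upsert either creates the checkpoint document or overwrites the stale one in a single call, so the worker never has to check existence first. A hedged sketch (the real worker obtains the source map from CheckpointDao.toIndexSource):

    import java.util.Map;

    import org.opensearch.action.update.UpdateRequest;

    public final class UpsertCheckpointSketch {
        // One request, two outcomes: insert when the checkpoint is new,
        // update when it already exists.
        static UpdateRequest toUpsert(String indexName, String modelId, Map<String, Object> source) {
            return new UpdateRequest(indexName, modelId).doc(source).docAsUpsert(true);
        }
    }
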
@@ -212,20 +210,20 @@ private ActionListener> onGetDetector( LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get config [{}]", configId), exception); }); } - public void writeAll(List> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) { - ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { + public void writeAll(List> modelStates, String configId, boolean forceWrite, RequestPriority priority) { + ActionListener> onGetForAll = ActionListener.wrap(detectorOptional -> { if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); + LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", configId)); return; } - AnomalyDetector detector = detectorOptional.get(); + Config detector = detectorOptional.get(); try { List allRequests = new ArrayList<>(); - for (ModelState state : modelStates) { + for (ModelState state : modelStates) { Instant instant = state.getLastCheckpointTime(); if (!checkpoint.shouldSave(instant, forceWrite, checkpointInterval, clock)) { continue; @@ -244,7 +242,7 @@ public void writeAll(List> modelStates, String detectorI .add( new CheckpointWriteRequest( System.currentTimeMillis() + detector.getIntervalInMilliseconds(), - detectorId, + configId, priority, // If the document does not already exist, the contents of the upsert element // are inserted as a new document. @@ -264,11 +262,11 @@ public void writeAll(List> modelStates, String detectorI // As we are gonna retry serializing either when the entity is // evicted out of cache or during the next maintenance period, // don't do anything when the exception happens. - LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e); + LOG.info(new ParameterizedMessage("Exception while serializing models for [{}]", configId), e); } - }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); }); + }, exception -> { LOG.error(new ParameterizedMessage("fail to get detector [{}]", configId), exception); }); - nodeStateManager.getAnomalyDetector(detectorId, onGetForAll); + nodeStateManager.getConfig(configId, context, onGetForAll); } } diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java new file mode 100644 index 000000000..acaaa820b --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdEntityWorker.java @@ -0,0 +1,101 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.transport.ResultBulkRequest; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class ColdEntityWorker, ResultWriteBatchRequestType extends ResultBulkRequest, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, RCFResultType extends IntermediateResult, ModelManagerType extends ModelManager, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker, ResultHandlerType extends IndexMemoryPressureAwareResultHandler, ResultWriteWorkerType extends ResultWriteWorker, CheckpointReadWorkerType extends CheckpointReadWorker> + extends ScheduledWorker { + + public ColdEntityWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + CheckpointReadWorkerType checkpointReadQueue, + Duration stateTtl, + NodeStateManager nodeStateManager, + Setting checkpointReadBatchSizeSetting, + Setting expectedColdEntityExecutionMillsSetting, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + checkpointReadQueue, + stateTtl, + nodeStateManager, + context + ); + + this.batchSize = checkpointReadBatchSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(checkpointReadBatchSizeSetting, it -> this.batchSize = it); + + this.expectedExecutionTimeInMilliSecsPerRequest = expectedColdEntityExecutionMillsSetting.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(expectedColdEntityExecutionMillsSetting, it -> this.expectedExecutionTimeInMilliSecsPerRequest = it); + } + + @Override + protected List transformRequests(List requests) { + // guarantee we 
only send low priority requests + return requests.stream().filter(request -> request.getPriority() == RequestPriority.LOW).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java new file mode 100644 index 000000000..6ab17677f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ColdStartWorker.java @@ -0,0 +1,147 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.time.Clock; +import java.time.Duration; +import java.util.Locale; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.util.ExceptionUtil; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public abstract class ColdStartWorker & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ColdStarterType extends ModelColdStart, CacheType extends TimeSeriesCache> + extends SingleRequestWorker { + private static final Logger LOG = LogManager.getLogger(ColdStartWorker.class); + + protected final ColdStarterType coldStarter; + protected final CacheType cacheProvider; + + public ColdStartWorker( + String workerName, + long heapSizeInBytes, + int singleRequestSizeInBytes, + Setting maxHeapPercentForQueueSetting, + ClusterService clusterService, + Random random, + CircuitBreakerService adCircuitBreakerService, + ThreadPool threadPool, + Settings settings, + float maxQueuedTaskRatio, + Clock clock, + float mediumSegmentPruneRatio, + float lowSegmentPruneRatio, + int maintenanceFreqConstant, + Setting concurrency, + Duration executionTtl, + ColdStarterType coldStarter, + Duration stateTtl, + NodeStateManager nodeStateManager, + CacheType cacheProvider, + AnalysisType context + ) { + super( + workerName, + heapSizeInBytes, + singleRequestSizeInBytes, + maxHeapPercentForQueueSetting, + clusterService, + random, + adCircuitBreakerService, + threadPool, + settings, + maxQueuedTaskRatio, + clock, + mediumSegmentPruneRatio, + lowSegmentPruneRatio, + maintenanceFreqConstant, + concurrency, + executionTtl, + stateTtl, + nodeStateManager, + context + ); + this.coldStarter = coldStarter; + this.cacheProvider = cacheProvider; + + } + + @Override + protected void 
executeRequest(FeatureRequest coldStartRequest, ActionListener listener) { + String configId = coldStartRequest.getConfigId(); + + String modelId = coldStartRequest.getModelId(); + + if (null == modelId) { + String error = String.format(Locale.ROOT, "Fail to get model id for request %s", coldStartRequest); + LOG.warn(error); + listener.onFailure(new RuntimeException(error)); + return; + } + + ModelState modelState = createEmptyState(coldStartRequest, modelId, configId); + + ActionListener coldStartListener = ActionListener.wrap(r -> { + nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> { + try { + if (!configOptional.isPresent()) { + LOG + .error( + new ParameterizedMessage( + "fail to load trained model [{}] to cache due to the config not being found.", + modelState.getModelId() + ) + ); + return; + } + cacheProvider.hostIfPossible(configOptional.get(), modelState); + + } finally { + listener.onResponse(null); + } + }, listener::onFailure)); + + }, e -> { + try { + if (ExceptionUtil.isOverloaded(e)) { + LOG.error("OpenSearch is overloaded"); + setCoolDownStart(); + } + nodeStateManager.setException(configId, e); + } finally { + listener.onFailure(e); + } + }); + + coldStarter.trainModel(coldStartRequest.getEntity(), configId, modelState, coldStartListener); + } + + protected abstract ModelState createEmptyState(FeatureRequest coldStartRequest, String modelId, String configId); +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java similarity index 94% rename from src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java index 62bd0a2bd..bc81b1650 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ConcurrentWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ConcurrentWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -19,13 +19,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; /** * A queue to run concurrent requests (either batch or single request). 
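
ColdStartWorker.executeRequest above shows the listener discipline used throughout these queues: the success branch completes the outer listener in a finally block so a cache failure cannot leak a permit, while the failure branch records the exception and backs off when the cluster is overloaded. The success-side shape in a compact sketch (hostModel stands in for the getConfig-then-hostIfPossible chain):

    import org.opensearch.action.ActionListener;

    public final class ColdStartListenerSketch {
        // On success try to cache the trained model, but complete the outer
        // listener no matter what; on failure just propagate.
        static ActionListener<Void> coldStartListener(Runnable hostModel, ActionListener<Void> outer) {
            return ActionListener.wrap(r -> {
                try {
                    hostModel.run();
                } finally {
                    outer.onResponse(null);
                }
            }, outer::onFailure);
        }
    }
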
@@ -74,7 +75,7 @@ public ConcurrentWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -85,7 +86,8 @@ public ConcurrentWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -103,7 +105,8 @@ public ConcurrentWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.permits = new Semaphore(concurrencySetting.get(settings)); diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java new file mode 100644 index 000000000..65cd02a69 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/FeatureRequest.java @@ -0,0 +1,71 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.ratelimit; + +import java.util.Optional; + +import org.opensearch.timeseries.model.Entity; + +public class FeatureRequest extends QueuedRequest { + private final double[] currentFeature; + private final long dataStartTimeMillis; + protected final String modelId; + private final Optional entity; + + // used in HC + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + double[] currentFeature, + long dataStartTimeMs, + Entity entity + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = entity.getModelId(configId).isEmpty() ? null : entity.getModelId(configId).get(); + this.entity = Optional.ofNullable(entity); + } + + // used in single-stream + public FeatureRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + String modelId, + double[] currentFeature, + long dataStartTimeMs + ) { + super(expirationEpochMs, configId, priority); + this.currentFeature = currentFeature; + this.dataStartTimeMillis = dataStartTimeMs; + this.modelId = modelId; + this.entity = Optional.empty(); + } + + public double[] getCurrentFeature() { + return currentFeature; + } + + public long getDataStartTimeMillis() { + return dataStartTimeMillis; + } + + public String getModelId() { + return modelId; + } + + public Optional getEntity() { + return entity; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java rename to src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java index 66c440db9..a13a490de 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/QueuedRequest.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/QueuedRequest.java @@ -9,22 +9,22 @@ * GitHub history for details. 
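
FeatureRequest, added above, collapses what used to be separate entity and single-stream request types: the HC constructor derives the model id from the Entity, while the single-stream constructor takes it directly and leaves the entity empty. Usage of the single-stream flavor, with made-up ids and values:

    import org.opensearch.timeseries.ratelimit.FeatureRequest;
    import org.opensearch.timeseries.ratelimit.RequestPriority;

    public final class FeatureRequestDemo {
        public static void main(String[] args) {
            long expiry = System.currentTimeMillis() + 60_000L;
            double[] features = new double[] { 1.0, 2.0 };
            // Single-stream flavor: the caller supplies the model id; the HC flavor
            // would instead derive it via entity.getModelId(configId).
            FeatureRequest request = new FeatureRequest(expiry, "config-1", RequestPriority.MEDIUM, "model-1", features, 0L);
            System.out.println(request.getModelId());
        }
    }
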
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public abstract class QueuedRequest { protected long expirationEpochMs; - protected String detectorId; + protected String configId; protected RequestPriority priority; /** * * @param expirationEpochMs Request expiry time in milliseconds - * @param detectorId Detector Id + * @param configId Detector Id * @param priority how urgent the request is */ - protected QueuedRequest(long expirationEpochMs, String detectorId, RequestPriority priority) { + protected QueuedRequest(long expirationEpochMs, String configId, RequestPriority priority) { this.expirationEpochMs = expirationEpochMs; - this.detectorId = detectorId; + this.configId = configId; this.priority = priority; } @@ -47,11 +47,11 @@ public void setPriority(RequestPriority priority) { this.priority = priority; } - public String getId() { - return detectorId; + public String getConfigId() { + return configId; } public void setDetectorId(String detectorId) { - this.detectorId = detectorId; + this.configId = detectorId; } } diff --git a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java similarity index 96% rename from src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java index 770c79b96..27bd23630 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RateLimitedRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RateLimitedRequestWorker.java @@ -9,9 +9,9 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; import java.time.Clock; import java.time.Duration; @@ -33,17 +33,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.opensearch.ad.ExpiringState; -import org.opensearch.ad.MaintenanceState; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPoolStats; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExpiringState; +import org.opensearch.timeseries.MaintenanceState; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.TimeSeriesException; /** @@ -175,7 +176,7 @@ public int clearExpiredRequests() { protected final ConcurrentSkipListMap requestQueues; private String lastSelectedRequestQueueId; protected Random random; - private ADCircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService adCircuitBreakerService; protected ThreadPool threadPool; protected Instant cooldownStart; protected int coolDownMinutes; @@ -186,6 +187,7 @@ public int clearExpiredRequests() { protected int maintenanceFreqConstant; private final Duration stateTtl; 
protected final NodeStateManager nodeStateManager; + protected final AnalysisType context; public RateLimitedRequestWorker( String workerName, @@ -194,7 +196,7 @@ public RateLimitedRequestWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -203,7 +205,8 @@ public RateLimitedRequestWorker( float lowRequestQueuePruneRatio, int maintenanceFreqConstant, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { this.heapSize = heapSizeInBytes; this.singleRequestSize = singleRequestSizeInBytes; @@ -228,10 +231,11 @@ public RateLimitedRequestWorker( this.lastSelectedRequestQueueId = null; this.requestQueues = new ConcurrentSkipListMap<>(); this.cooldownStart = Instant.MIN; - this.coolDownMinutes = (int) (COOLDOWN_MINUTES.get(settings).getMinutes()); + this.coolDownMinutes = (int) (AD_COOLDOWN_MINUTES.get(settings).getMinutes()); this.maintenanceFreqConstant = maintenanceFreqConstant; this.stateTtl = stateTtl; this.nodeStateManager = nodeStateManager; + this.context = context; } protected String getWorkerName() { @@ -305,7 +309,7 @@ protected void putOnly(RequestType request) { // just use the RequestQueue priority (i.e., low or high) as the key of the RequestQueue map. RequestQueue requestQueue = requestQueues .computeIfAbsent( - RequestPriority.MEDIUM == request.getPriority() ? request.getId() : request.getPriority().name(), + RequestPriority.MEDIUM == request.getPriority() ? request.getConfigId() : request.getPriority().name(), k -> new RequestQueue() ); diff --git a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java similarity index 88% rename from src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java rename to src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java index 3193d2285..29fb14523 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/RequestPriority.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/RequestPriority.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; public enum RequestPriority { LOW, diff --git a/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java new file mode 100644 index 000000000..9eba5d1b5 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteRequest.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.ratelimit; + +import java.io.IOException; + +import org.opensearch.ad.model.AnomalyResult; +import org.opensearch.ad.ratelimit.ADResultWriteRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.forecast.model.ForecastResult; +import org.opensearch.forecast.ratelimit.ForecastResultWriteRequest; +import org.opensearch.timeseries.model.IndexableResult; + +public abstract class ResultWriteRequest<ResultType extends IndexableResult> extends QueuedRequest implements Writeable { + private final ResultType result; + // If resultIndex is null, result will be stored in default result index. + private final String resultIndex; + + public ResultWriteRequest(long expirationEpochMs, String detectorId, RequestPriority priority, ResultType result, String resultIndex) { + super(expirationEpochMs, detectorId, priority); + this.result = result; + this.resultIndex = resultIndex; + } + + /** + * + * @param <T> subclass type + * @param <R> result type + * @param expirationEpochMs expiration epoch in milliseconds + * @param configId config id + * @param priority request priority + * @param result result + * @param resultIndex result index + * @param clazz The clazz parameter is used to pass the class object of the desired subtype, which allows us to perform a dynamic cast to T and return the correctly-typed instance. + * @return the created request, cast to the requested subtype + */ + public static <T extends ResultWriteRequest<R>, R extends IndexableResult> T create( + long expirationEpochMs, + String configId, + RequestPriority priority, + IndexableResult result, + String resultIndex, + Class<T> clazz + ) { + if (result instanceof AnomalyResult) { + return clazz.cast(new ADResultWriteRequest(expirationEpochMs, configId, priority, (AnomalyResult) result, resultIndex)); + } else if (result instanceof ForecastResult) { + return clazz.cast(new ForecastResultWriteRequest(expirationEpochMs, configId, priority, (ForecastResult) result, resultIndex)); + } else { + throw new IllegalArgumentException("Unsupported result type"); + } + } + + public ResultWriteRequest(StreamInput in, Writeable.Reader<ResultType> resultReader) throws IOException { + this.result = resultReader.read(in); + this.resultIndex = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + result.writeTo(out); + out.writeOptionalString(resultIndex); + } + + public ResultType getResult() { + return result; + } + + public String getResultIndex() { + return resultIndex; + } +} diff --git a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java similarity index 57% rename from src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java index 2381e5db9..8bc5da450 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ResultWriteWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ResultWriteWorker.java @@ -1,19 +1,11 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details.
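
ResultWriteRequest.create above uses a class-token factory: the caller passes the desired subtype's Class so the factory can return a correctly typed instance through Class.cast instead of an unchecked cast at every call site. The pattern reduced to pure JDK:

    public final class TokenFactorySketch {
        abstract static class Base {}

        static final class A extends Base {}

        // The Class token lets the compiler see the return type while the
        // runtime check lives in Class.cast, as in ResultWriteRequest.create.
        static <T extends Base> T create(String kind, Class<T> clazz) {
            if ("a".equals(kind)) {
                return clazz.cast(new A());
            }
            throw new IllegalArgumentException("Unsupported kind");
        }
    }

One design note: because the factory switches on the result's runtime type, adding a third analysis type means touching this method; the Class token only guarantees the cast, not exhaustiveness.
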
*/ -package org.opensearch.ad.ratelimit; - -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY; +package org.opensearch.timeseries.ratelimit; +import java.io.IOException; import java.time.Clock; import java.time.Duration; import java.util.List; @@ -26,15 +18,8 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.DocWriteRequest; import org.opensearch.action.index.IndexRequest; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyResult; -import org.opensearch.ad.transport.ADResultBulkRequest; -import org.opensearch.ad.transport.ADResultBulkResponse; -import org.opensearch.ad.transport.handler.MultiEntityResultHandler; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedFunction; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; @@ -44,21 +29,33 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.threadpool.ThreadPool; - -public class ResultWriteWorker extends BatchWorker { +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.transport.ResultBulkRequest; +import org.opensearch.timeseries.transport.ResultBulkResponse; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.util.ExceptionUtil; + +public abstract class ResultWriteWorker, BatchRequestType extends ResultBulkRequest, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, ResultHandlerType extends IndexMemoryPressureAwareResultHandler> + extends BatchWorker { private static final Logger LOG = LogManager.getLogger(ResultWriteWorker.class); - public static final String WORKER_NAME = "result-write"; - - private final MultiEntityResultHandler resultHandler; - private NamedXContentRegistry xContentRegistry; + protected final ResultHandlerType resultHandler; + protected NamedXContentRegistry xContentRegistry; + private CheckedFunction resultParser; public ResultWriteWorker( - long heapSizeInBytes, - int singleRequestSizeInBytes, + String queueName, + long heapSize, + int singleRequestSize, Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -66,16 +63,20 @@ public ResultWriteWorker( float mediumSegmentPruneRatio, float lowSegmentPruneRatio, int maintenanceFreqConstant, + Setting concurrencySetting, Duration executionTtl, - MultiEntityResultHandler resultHandler, + Setting batchSizeSetting, + Duration stateTtl, + NodeStateManager timeSeriesNodeStateManager, + ResultHandlerType resultHandler, 
NamedXContentRegistry xContentRegistry, - NodeStateManager stateManager, - Duration stateTtl + CheckedFunction resultParser, + AnalysisType context ) { super( - WORKER_NAME, - heapSizeInBytes, - singleRequestSizeInBytes, + queueName, + heapSize, + singleRequestSize, maxHeapPercentForQueueSetting, clusterService, random, @@ -87,18 +88,20 @@ public ResultWriteWorker( mediumSegmentPruneRatio, lowSegmentPruneRatio, maintenanceFreqConstant, - AD_RESULT_WRITE_QUEUE_CONCURRENCY, + concurrencySetting, executionTtl, - AD_RESULT_WRITE_QUEUE_BATCH_SIZE, + batchSizeSetting, stateTtl, - stateManager + timeSeriesNodeStateManager, + context ); this.resultHandler = resultHandler; this.xContentRegistry = xContentRegistry; + this.resultParser = resultParser; } @Override - protected void executeBatchRequest(ADResultBulkRequest request, ActionListener listener) { + protected void executeBatchRequest(BatchRequestType request, ActionListener listener) { if (request.numberOfActions() < 1) { listener.onResponse(null); return; @@ -107,19 +110,7 @@ protected void executeBatchRequest(ADResultBulkRequest request, ActionListener toProcess) { - final ADResultBulkRequest bulkRequest = new ADResultBulkRequest(); - for (ResultWriteRequest request : toProcess) { - bulkRequest.add(request); - } - return bulkRequest; - } - - @Override - protected ActionListener getResponseListener( - List toProcess, - ADResultBulkRequest bulkRequest - ) { + protected ActionListener getResponseListener(List toProcess, BatchRequestType bulkRequest) { return ActionListener.wrap(adResultBulkResponse -> { if (adResultBulkResponse == null || false == adResultBulkResponse.getRetryRequests().isPresent()) { // all successful @@ -136,8 +127,8 @@ protected ActionListener getResponseListener( setCoolDownStart(); } - for (ResultWriteRequest request : toProcess) { - nodeStateManager.setException(request.getId(), exception); + for (ResultWriteRequestType request : toProcess) { + nodeStateManager.setException(request.getConfigId(), exception); } LOG.error("Fail to save results", exception); }); @@ -148,50 +139,18 @@ private void enqueueRetryRequestIteration(List requestToRetry, int return; } DocWriteRequest currentRequest = requestToRetry.get(index); - Optional resultToRetry = getAnomalyResult(currentRequest); + Optional resultToRetry = getAnomalyResult(currentRequest); if (false == resultToRetry.isPresent()) { enqueueRetryRequestIteration(requestToRetry, index + 1); return; } - AnomalyResult result = resultToRetry.get(); - String detectorId = result.getConfigId(); - nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(requestToRetry, index, detectorId, result)); - } - - private ActionListener> onGetDetector( - List requestToRetry, - int index, - String detectorId, - AnomalyResult resultToRetry - ) { - return ActionListener.wrap(detectorOptional -> { - if (false == detectorOptional.isPresent()) { - LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId)); - enqueueRetryRequestIteration(requestToRetry, index + 1); - return; - } - - AnomalyDetector detector = detectorOptional.get(); - super.put( - new ResultWriteRequest( - // expire based on execute start time - resultToRetry.getExecutionStartTime().toEpochMilli() + detector.getIntervalInMilliseconds(), - detectorId, - resultToRetry.isHighPriority() ? 
RequestPriority.HIGH : RequestPriority.MEDIUM, - resultToRetry, - detector.getCustomResultIndex() - ) - ); - - enqueueRetryRequestIteration(requestToRetry, index + 1); - }, exception -> { - LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception); - enqueueRetryRequestIteration(requestToRetry, index + 1); - }); + ResultType result = resultToRetry.get(); + String id = result.getConfigId(); + nodeStateManager.getConfig(id, context, onGetDetector(requestToRetry, index, id, result)); } - private Optional getAnomalyResult(DocWriteRequest request) { + protected Optional getAnomalyResult(DocWriteRequest request) { try { if (false == (request instanceof IndexRequest)) { LOG.error(new ParameterizedMessage("We should only send IndexRquest, but get [{}].", request)); @@ -209,11 +168,52 @@ private Optional getAnomalyResult(DocWriteRequest request) { // org.opensearch.core.common.ParsingException: Failed to parse object: expecting token of type [START_OBJECT] but found // [null] xContentParser.nextToken(); - return Optional.of(AnomalyResult.parse(xContentParser)); + return Optional.of(resultParser.apply(xContentParser)); } } catch (Exception e) { LOG.error(new ParameterizedMessage("Fail to parse index request [{}]", request), e); } return Optional.empty(); } + + private ActionListener> onGetDetector( + List requestToRetry, + int index, + String id, + ResultType resultToRetry + ) { + return ActionListener.wrap(configOptional -> { + if (false == configOptional.isPresent()) { + LOG.warn(new ParameterizedMessage("Config [{}] is not available.", id)); + enqueueRetryRequestIteration(requestToRetry, index + 1); + return; + } + + Config config = configOptional.get(); + super.put( + createResultWriteRequest( + // expire based on execute start time + resultToRetry.getExecutionStartTime().toEpochMilli() + config.getIntervalInMilliseconds(), + id, + resultToRetry.isHighPriority() ? RequestPriority.HIGH : RequestPriority.MEDIUM, + resultToRetry, + config.getCustomResultIndex() + ) + ); + + enqueueRetryRequestIteration(requestToRetry, index + 1); + + }, exception -> { + LOG.error(new ParameterizedMessage("fail to get config [{}]", id), exception); + enqueueRetryRequestIteration(requestToRetry, index + 1); + }); + } + + protected abstract ResultWriteRequestType createResultWriteRequest( + long expirationEpochMs, + String configId, + RequestPriority priority, + ResultType result, + String resultIndex + ); } diff --git a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java similarity index 90% rename from src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java index 9d4891b7c..86a171d74 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/ScheduledWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/ScheduledWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
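When a bulk response asks for retries, the worker above re-parses each pending IndexRequest back into a result object before requeueing it. A standalone sketch of that parse-back step for the AD case; the helper name is hypothetical and the registry would be the node's NamedXContentRegistry:

import java.util.Optional;

import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;

public class RetryParseExample {
    // Only IndexRequests carry a source we can re-parse; anything else is skipped,
    // matching the guard in getAnomalyResult.
    static Optional<AnomalyResult> parseBack(DocWriteRequest<?> request, NamedXContentRegistry registry) {
        if (!(request instanceof IndexRequest)) {
            return Optional.empty();
        }
        IndexRequest indexRequest = (IndexRequest) request;
        try (XContentParser parser = XContentHelper
            .createParser(registry, LoggingDeprecationHandler.INSTANCE, indexRequest.source(), indexRequest.getContentType())) {
            parser.nextToken(); // advance to START_OBJECT before handing off to the parser
            return Optional.of(AnomalyResult.parse(parser));
        } catch (Exception e) {
            return Optional.empty(); // unparseable requests are dropped, as in the worker
        }
    }
}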
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -18,18 +18,20 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.breaker.CircuitBreakerService; public abstract class ScheduledWorker extends RateLimitedRequestWorker { - private static final Logger LOG = LogManager.getLogger(ColdEntityWorker.class); + private static final Logger LOG = LogManager.getLogger(ADColdEntityWorker.class); // the number of requests forwarded to the target queue protected volatile int batchSize; @@ -45,7 +47,7 @@ public ScheduledWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -55,7 +57,8 @@ public ScheduledWorker( int maintenanceFreqConstant, RateLimitedRequestWorker targetQueue, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( workerName, @@ -73,7 +76,8 @@ public ScheduledWorker( lowSegmentPruneRatio, maintenanceFreqConstant, stateTtl, - nodeStateManager + nodeStateManager, + context ); this.targetQueue = targetQueue; diff --git a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java similarity index 89% rename from src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java rename to src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java index 028a0643f..92d4f73d2 100644 --- a/src/main/java/org/opensearch/ad/ratelimit/SingleRequestWorker.java +++ b/src/main/java/org/opensearch/timeseries/ratelimit/SingleRequestWorker.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.ratelimit; +package org.opensearch.timeseries.ratelimit; import java.time.Clock; import java.time.Duration; @@ -20,12 +20,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.breaker.CircuitBreakerService; public abstract class SingleRequestWorker extends ConcurrentWorker { private static final Logger LOG = LogManager.getLogger(SingleRequestWorker.class); @@ -37,7 +38,7 @@ public SingleRequestWorker( Setting maxHeapPercentForQueueSetting, ClusterService clusterService, Random random, - ADCircuitBreakerService adCircuitBreakerService, + CircuitBreakerService adCircuitBreakerService, ThreadPool threadPool, Settings settings, float maxQueuedTaskRatio, @@ -48,7 +49,8 @@ public SingleRequestWorker( Setting concurrencySetting, Duration executionTtl, Duration stateTtl, - NodeStateManager nodeStateManager + NodeStateManager nodeStateManager, + AnalysisType context ) { super( queueName, @@ -68,7 +70,8 @@ public SingleRequestWorker( concurrencySetting, executionTtl, stateTtl, - nodeStateManager + nodeStateManager, + context ); } diff --git a/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java new file mode 100644 index 000000000..f31e6ce0c --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/RestJobAction.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.timeseries.model.DateRange; + +import com.google.common.collect.ImmutableList; + +public abstract class RestJobAction extends BaseRestHandler { + protected DateRange parseInputDateRange(RestRequest request) throws IOException { + if (!request.hasContent()) { + return null; + } + XContentParser parser = request.contentParser(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + DateRange dateRange = DateRange.parse(parser); + return dateRange; + } + + @Override + public List routes() { + return ImmutableList.of(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java new file mode 100644 index 000000000..5f0a998dd --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/AbstractTimeSeriesActionHandler.java @@ -0,0 +1,838 @@ +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG; +import static org.opensearch.timeseries.util.ParseUtils.parseAggregators; 
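A hypothetical concrete handler built on the RestJobAction base class above, showing how an endpoint could consume the optional date-range body (parseInputDateRange returns null when the request has no content, i.e. a realtime job with no historical range). Everything except RestJobAction and DateRange is a placeholder:

import java.io.IOException;

import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.RestRequest;
import org.opensearch.timeseries.model.DateRange;
import org.opensearch.timeseries.rest.RestJobAction;

public class RestStartJobExample extends RestJobAction {
    @Override
    public String getName() {
        return "example_start_job_action"; // hypothetical action name
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        DateRange range = parseInputDateRange(request); // null => realtime job
        // ... build and execute the transport request using `range` ...
        return channel -> {};
    }
}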
+import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.isExceptionCausedByInvalidQuery; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.lang.StringUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsAction; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsRequest; +import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.replication.ReplicationResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.search.aggregations.AggregatorFactories; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.MergeableList; +import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.model.ValidationIssueType; +import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.Sets; + +public abstract class AbstractTimeSeriesActionHandler & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + + protected 
final Logger logger = LogManager.getLogger(AbstractTimeSeriesActionHandler.class); + + public static final String NAME_REGEX = "[a-zA-Z0-9._-]+"; + public static final Integer MAX_NAME_SIZE = 64; + public static final String CATEGORY_NOT_FOUND_ERR_MSG = "Can't find the categorical field %s"; + + public static String INVALID_NAME_SIZE = "Name should be shortened. The maximum limit is " + + AbstractTimeSeriesActionHandler.MAX_NAME_SIZE + + " characters."; + + public static final Set ALL_VALIDATION_ASPECTS_STRS = Arrays + .asList(ValidationAspect.values()) + .stream() + .map(aspect -> aspect.getName()) + .collect(Collectors.toSet()); + + protected final Config config; + protected final IndexManagement timeSeriesIndices; + protected final boolean isDryRun; + protected final Client client; + protected final String id; + protected final SecurityClientUtil clientUtil; + protected final User user; + protected final RestRequest.Method method; + protected final ConfigUpdateConfirmer handler = new ConfigUpdateConfirmer(); + protected final ClusterService clusterService; + protected final NamedXContentRegistry xContentRegistry; + protected final TransportService transportService; + protected final TimeValue requestTimeout; + protected final WriteRequest.RefreshPolicy refreshPolicy; + protected final Long seqNo; + protected final Long primaryTerm; + protected final String validationType; + protected final SearchFeatureDao searchFeatureDao; + protected final Integer maxFeatures; + protected final Integer maxCategoricalFields; + protected final AnalysisType context; + + public AbstractTimeSeriesActionHandler( + Config config, + IndexManagement timeSeriesIndices, + boolean isDryRun, + Client client, + String id, + SecurityClientUtil clientUtil, + User user, + RestRequest.Method method, + ClusterService clusterService, + NamedXContentRegistry xContentRegistry, + TransportService transportService, + TimeValue requestTimeout, + WriteRequest.RefreshPolicy refreshPolicy, + Long seqNo, + Long primaryTerm, + String validationType, + SearchFeatureDao searchFeatureDao, + Integer maxFeatures, + Integer maxCategoricalFields, + AnalysisType context + ) { + this.config = config; + this.timeSeriesIndices = timeSeriesIndices; + this.isDryRun = isDryRun; + this.client = client; + this.id = id == null ? "" : id; + this.clientUtil = clientUtil; + this.user = user; + this.method = method; + this.clusterService = clusterService; + this.xContentRegistry = xContentRegistry; + this.transportService = transportService; + this.requestTimeout = requestTimeout; + this.refreshPolicy = refreshPolicy; + this.seqNo = seqNo; + this.primaryTerm = primaryTerm; + this.validationType = validationType; + this.searchFeatureDao = searchFeatureDao; + this.maxFeatures = maxFeatures; + this.maxCategoricalFields = maxCategoricalFields; + this.context = context; + } + + /** + * Start function to process create/update/validate config request. + * + * If validation type is detector then all validation in this class involves validation checks + * against the configurations. + * Any issues raised here would block user from creating the config (e.g., anomaly detector). + * If validation Aspect is of type model then further non-blocker validation will be executed + * after the blocker validation is executed. Any issues that are raised for model validation + * are simply warnings for the user in terms of how configuration could be changed to lead to + * a higher likelihood of model training completing successfully. 
+ * + * For custom index validation, if config is not using custom result index, check if config + * index exist first, if not, will create first. Otherwise, check if custom + * result index exists or not. If exists, will check if index mapping matches + * config result index mapping and if user has correct permission to write index. + * If doesn't exist, will create custom result index with result index + * mapping. + */ + public void start(ActionListener listener) { + String resultIndex = config.getCustomResultIndex(); + // use default detector result index which is system index + if (resultIndex == null) { + createOrUpdateConfig(listener); + return; + } + + if (this.isDryRun) { + if (timeSeriesIndices.doesIndexExist(resultIndex)) { + timeSeriesIndices + .validateCustomResultIndexAndExecute( + resultIndex, + () -> createOrUpdateConfig(listener), + ActionListener.wrap(r -> createOrUpdateConfig(listener), ex -> { + logger.error(ex); + listener.onFailure(createValidationException(ex.getMessage(), ValidationIssueType.RESULT_INDEX)); + return; + }) + ); + return; + } else { + createOrUpdateConfig(listener); + return; + } + } + // use custom result index if not validating and resultIndex not null + timeSeriesIndices.initCustomResultIndexAndExecute(resultIndex, () -> createOrUpdateConfig(listener), listener); + } + + // if isDryRun is true then this method is being executed through Validation API meaning actual + // index won't be created, only validation checks will be executed throughout the class + private void createOrUpdateConfig(ActionListener listener) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (!timeSeriesIndices.doesConfigIndexExist() && !this.isDryRun) { + logger.info("Config Indices do not exist"); + timeSeriesIndices + .initConfigIndex( + ActionListener + .wrap( + response -> onCreateMappingsResponse(response, false, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + logger.info("DryRun variable " + this.isDryRun); + validateName(this.isDryRun, listener); + } + } catch (Exception e) { + logger.error("Failed to create or update forecaster " + id, e); + listener.onFailure(e); + } + } + + protected void validateName(boolean indexingDryRun, ActionListener listener) { + if (!config.getName().matches(NAME_REGEX)) { + listener.onFailure(createValidationException(CommonMessages.INVALID_NAME, ValidationIssueType.NAME)); + return; + + } + if (config.getName().length() > MAX_NAME_SIZE) { + listener.onFailure(createValidationException(AbstractTimeSeriesActionHandler.INVALID_NAME_SIZE, ValidationIssueType.NAME)); + return; + } + validateTimeField(indexingDryRun, listener); + } + + protected void validateTimeField(boolean indexingDryRun, ActionListener listener) { + String givenTimeField = config.getTimeField(); + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(givenTimeField); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + // comments explaining fieldMappingResponse parsing can be found inside validateCategoricalField(String, boolean) + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + boolean foundField = false; + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + + GetFieldMappingsResponse.FieldMappingMetadata 
fieldMetadata = field2Metadata.getValue(); + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.DATE_TYPE)) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.INVALID_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + } + } + } + } + } + } + if (!foundField) { + listener + .onFailure( + new ValidationException( + String.format(Locale.ROOT, CommonMessages.NON_EXISTENT_TIMESTAMP, givenTimeField), + ValidationIssueType.TIMEFIELD_FIELD, + ValidationAspect.FORECASTER + ) + ); + return; + } + prepareConfigIndexing(indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + /** + * Prepare for indexing a new config. + * @param indexingDryRun if this is dryrun for indexing; when validation, it is true; when create/update, it is false + */ + protected void prepareConfigIndexing(boolean indexingDryRun, ActionListener listener) { + if (method == RestRequest.Method.PUT) { + handler + .confirmJobRunning( + clusterService, + client, + id, + listener, + () -> updateConfig(id, indexingDryRun, listener), + xContentRegistry + ); + } else { + createConfig(indexingDryRun, listener); + } + } + + protected void updateConfig(String id, boolean indexingDryRun, ActionListener listener) { + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX, id); + client + .get( + request, + ActionListener + .wrap( + response -> onGetConfigResponse(response, indexingDryRun, id, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetConfigResponse(GetResponse response, boolean indexingDryRun, String id, ActionListener listener) { + if (!response.isExists()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + id, RestStatus.NOT_FOUND)); + return; + } + try (XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Config existingConfig = parse(parser, response); + // If category field changed, frontend may not be able to render AD result for different config types correctly. + // For example, if an anomaly detector changed from HC to single entity detector, AD result page may show multiple anomaly + // result points on the same time point if there are multiple entities have anomaly results. + // If single-category HC changed category field from IP to error type, the AD result page may show both IP and error type + // in top N entities list. That's confusing. + // So we decide to block updating detector category field. 
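The nested-map walk in validateTimeField above (and validateCategoricalField below) is easy to get lost in. A minimal sketch of the same lookup as a standalone helper (hypothetical name), assuming the two-level mappings() shape the diff's own comments describe:

import java.util.Map;

import org.opensearch.action.admin.indices.mapping.get.GetFieldMappingsResponse;

public class FieldTypeLookupExample {
    // mappings() is keyed by index name, then by (possibly nested) field name;
    // sourceAsMap() yields something like {host2={type=keyword}} for host_nest.host2.
    static String resolveFieldType(GetFieldMappingsResponse response, String field) {
        for (Map<String, GetFieldMappingsResponse.FieldMappingMetadata> byField : response.mappings().values()) {
            GetFieldMappingsResponse.FieldMappingMetadata metadata = byField.get(field);
            if (metadata == null) {
                continue; // field not mapped in this index
            }
            Map<String, Object> source = metadata.sourceAsMap();
            for (Object typeInfo : source.values()) {
                if (typeInfo instanceof Map) {
                    Object type = ((Map<?, ?>) typeInfo).get("type");
                    if (type != null) {
                        return type.toString(); // e.g. "date", "keyword", "ip"
                    }
                }
            }
        }
        return null; // field not found in any index
    }
}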
+ if (!ParseUtils.listEqualsWithoutConsideringOrder(existingConfig.getCategoryFields(), config.getCategoryFields())) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CATEGORY_FIELD, RestStatus.BAD_REQUEST)); + return; + } + if (!Objects.equals(existingConfig.getCustomResultIndex(), config.getCustomResultIndex())) { + listener + .onFailure(new OpenSearchStatusException(CommonMessages.CAN_NOT_CHANGE_CUSTOM_RESULT_INDEX, RestStatus.BAD_REQUEST)); + return; + } + + ActionListener confirmHistoricalRunningListener = ActionListener + .wrap( + r -> searchConfigInputIndices(id, indexingDryRun, listener), + // can't update detector if there is AD task running + listener::onFailure + ); + + confirmHistoricalRunning(id, confirmHistoricalRunningListener); + } catch (IOException e) { + String message = "Failed to parse anomaly detector " + id; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + + } + + protected void validateAgainstExistingHCConfig(String detectorId, boolean indexingDryRun, ActionListener listener) { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.existsQuery(Config.CATEGORY_FIELD)); + + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchHCConfigResponse(response, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + + } + + protected void createConfig(boolean indexingDryRun, ActionListener listener) { + try { + List categoricalFields = config.getCategoryFields(); + if (categoricalFields != null && categoricalFields.size() > 0) { + validateAgainstExistingHCConfig(null, indexingDryRun, listener); + } else { + if (timeSeriesIndices.doesConfigIndexExist()) { + QueryBuilder query = QueryBuilders.matchAllQuery(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query).size(0).timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + + client + .search( + searchRequest, + ActionListener + .wrap( + response -> onSearchSingleStreamConfigResponse(response, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ) + ); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + protected void onSearchSingleStreamConfigResponse(SearchResponse response, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxSingleStreamConfigs()) { + String errorMsgSingleEntity = getExceedMaxSingleStreamConfigsErrorMsg(getMaxSingleStreamConfigs()); + logger.error(errorMsgSingleEntity); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsgSingleEntity, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsgSingleEntity)); + } else { + searchConfigInputIndices(null, indexingDryRun, listener); + } + } + + protected void onSearchHCConfigResponse(SearchResponse response, String detectorId, boolean indexingDryRun, 
ActionListener listener) + throws IOException { + if (response.getHits().getTotalHits().value >= getMaxHCConfigs()) { + String errorMsg = getExceedMaxHCConfigsErrorMsg(getMaxHCConfigs()); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.GENERAL_SETTINGS)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateCategoricalField(detectorId, indexingDryRun, listener); + } + } + + @SuppressWarnings("unchecked") + protected void validateCategoricalField(String detectorId, boolean indexingDryRun, ActionListener listener) { + List categoryField = config.getCategoryFields(); + + if (categoryField == null) { + searchConfigInputIndices(detectorId, indexingDryRun, listener); + return; + } + + // we only support a certain number of categorical field + // If there is more fields than required, Config's constructor + // throws validation exception before reaching this line + int maxCategoryFields = maxCategoricalFields; + if (categoryField.size() > maxCategoryFields) { + listener + .onFailure( + createValidationException(CommonMessages.getTooManyCategoricalFieldErr(maxCategoryFields), ValidationIssueType.CATEGORY) + ); + return; + } + + String categoryField0 = categoryField.get(0); + + GetFieldMappingsRequest getMappingsRequest = new GetFieldMappingsRequest(); + getMappingsRequest.indices(config.getIndices().toArray(new String[0])).fields(categoryField.toArray(new String[0])); + getMappingsRequest.indicesOptions(IndicesOptions.strictExpand()); + + ActionListener mappingsListener = ActionListener.wrap(getMappingsResponse -> { + // example getMappingsResponse: + // GetFieldMappingsResponse{mappings={server-metrics={_doc={service=FieldMappingMetadata{fullName='service', + // source=org.opensearch.core.common.bytes.BytesArray@7ba87dbd}}}}} + // for nested field, it would be + // GetFieldMappingsResponse{mappings={server-metrics={_doc={host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08}}}}} + boolean foundField = false; + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + Map> mappingsByIndex = getMappingsResponse.mappings(); + + for (Map mappingsByField : mappingsByIndex.values()) { + for (Map.Entry field2Metadata : mappingsByField.entrySet()) { + // example output: + // host_nest.host2=FieldMappingMetadata{fullName='host_nest.host2', + // source=org.opensearch.core.common.bytes.BytesArray@8fb4de08} + + // Review why the change from FieldMappingMetadata to GetFieldMappingsResponse.FieldMappingMetadata + + GetFieldMappingsResponse.FieldMappingMetadata fieldMetadata = field2Metadata.getValue(); + + if (fieldMetadata != null) { + // sourceAsMap returns sth like {host2={type=keyword}} with host2 being a nested field + Map fieldMap = fieldMetadata.sourceAsMap(); + if (fieldMap != null) { + for (Object type : fieldMap.values()) { + if (type != null && type instanceof Map) { + foundField = true; + Map metadataMap = (Map) type; + String typeName = (String) metadataMap.get(CommonName.TYPE); + if (!typeName.equals(CommonName.KEYWORD_TYPE) && !typeName.equals(CommonName.IP_TYPE)) { + listener + .onFailure( + createValidationException(CATEGORICAL_FIELD_TYPE_ERR_MSG, ValidationIssueType.CATEGORY) + ); + return; + } + } + } + } + + } + } + } + + if (foundField == false) { + listener + .onFailure( + createValidationException( + String.format(Locale.ROOT, 
CATEGORY_NOT_FOUND_ERR_MSG, categoryField0), + ValidationIssueType.CATEGORY + ) + ); + return; + } + + searchConfigInputIndices(detectorId, indexingDryRun, listener); + }, error -> { + String message = String.format(Locale.ROOT, "Fail to get the index mapping of %s", config.getIndices()); + logger.error(message, error); + listener.onFailure(new IllegalArgumentException(message)); + }); + + clientUtil + .executeWithInjectedSecurity(GetFieldMappingsAction.INSTANCE, getMappingsRequest, user, client, context, mappingsListener); + } + + protected void searchConfigInputIndices(String detectorId, boolean indexingDryRun, ActionListener listener) { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() + .query(QueryBuilders.matchAllQuery()) + .size(0) + .timeout(requestTimeout); + + SearchRequest searchRequest = new SearchRequest(config.getIndices().toArray(new String[0])).source(searchSourceBuilder); + + ActionListener searchResponseListener = ActionListener + .wrap( + searchResponse -> onSearchConfigInputIndicesResponse(searchResponse, detectorId, indexingDryRun, listener), + exception -> listener.onFailure(exception) + ); + + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + + protected void onSearchConfigInputIndicesResponse( + SearchResponse response, + String detectorId, + boolean indexingDryRun, + ActionListener listener + ) throws IOException { + if (response.getHits().getTotalHits().value == 0) { + String errorMsg = getNoDocsInUserIndexErrorMsg(Arrays.toString(config.getIndices().toArray(new String[0]))); + logger.error(errorMsg); + if (indexingDryRun) { + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.INDICES)); + return; + } + listener.onFailure(new IllegalArgumentException(errorMsg)); + } else { + validateConfigFeatures(detectorId, indexingDryRun, listener); + } + } + + protected void checkConfigNameExists(String detectorId, boolean indexingDryRun, ActionListener listener) throws IOException { + if (timeSeriesIndices.doesConfigIndexExist()) { + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + // src/main/resources/mappings/config.json#L14 + boolQueryBuilder.must(QueryBuilders.termQuery("name.keyword", config.getName())); + if (StringUtils.isNotBlank(detectorId)) { + boolQueryBuilder.mustNot(QueryBuilders.termQuery(RestHandlerUtils._ID, detectorId)); + } + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(boolQueryBuilder).timeout(requestTimeout); + SearchRequest searchRequest = new SearchRequest(CommonName.CONFIG_INDEX).source(searchSourceBuilder); + client + .search( + searchRequest, + ActionListener + .wrap( + searchResponse -> onSearchConfigNameResponse( + searchResponse, + detectorId, + config.getName(), + indexingDryRun, + listener + ), + exception -> listener.onFailure(exception) + ) + ); + } else { + tryIndexingConfig(indexingDryRun, listener); + } + + } + + protected void onSearchConfigNameResponse( + SearchResponse response, + String detectorId, + String name, + boolean indexingDryRun, + ActionListener listener + ) throws IOException { + if (response.getHits().getTotalHits().value > 0) { + String errorMsg = getDuplicateConfigErrorMsg( + name, + Arrays.stream(response.getHits().getHits()).map(hit -> hit.getId()).collect(Collectors.toList()) + ); + logger.warn(errorMsg); + listener.onFailure(createValidationException(errorMsg, ValidationIssueType.NAME)); + } else { + tryIndexingConfig(indexingDryRun, listener); + } + } + + 
protected void tryIndexingConfig(boolean indexingDryRun, ActionListener listener) throws IOException { + if (!indexingDryRun) { + indexConfig(id, listener); + } else { + finishConfigValidationOrContinueToModelValidation(listener); + } + } + + protected Set getValidationTypes(String validationType) { + if (StringUtils.isBlank(validationType)) { + return getDefaultValidationType(); + } else { + Set typesInRequest = new HashSet<>(Arrays.asList(validationType.split(","))); + return ValidationAspect + .getNames(Sets.intersection(AbstractTimeSeriesActionHandler.ALL_VALIDATION_ASPECTS_STRS, typesInRequest)); + } + } + + protected void finishConfigValidationOrContinueToModelValidation(ActionListener listener) { + logger.info("Skipping indexing detector. No blocking issue found so far."); + if (!getValidationTypes(validationType).contains(ValidationAspect.MODEL)) { + listener.onResponse(null); + } else { + validateModel(listener); + } + } + + @SuppressWarnings("unchecked") + protected void indexConfig(String id, ActionListener listener) throws IOException { + Config copiedConfig = copyConfig(user, config); + IndexRequest indexRequest = new IndexRequest(CommonName.CONFIG_INDEX) + .setRefreshPolicy(refreshPolicy) + .source(copiedConfig.toXContent(XContentFactory.jsonBuilder(), XCONTENT_WITH_TYPE)) + .setIfSeqNo(seqNo) + .setIfPrimaryTerm(primaryTerm) + .timeout(requestTimeout); + if (StringUtils.isNotBlank(id)) { + indexRequest.id(id); + } + + client.index(indexRequest, new ActionListener() { + @Override + public void onResponse(IndexResponse indexResponse) { + String errorMsg = checkShardsFailure(indexResponse); + if (errorMsg != null) { + listener.onFailure(new OpenSearchStatusException(errorMsg, indexResponse.status())); + return; + } + listener.onResponse(createIndexConfigResponse(indexResponse, copiedConfig)); + } + + @Override + public void onFailure(Exception e) { + logger.warn("Failed to update config", e); + if (e.getMessage() != null && e.getMessage().contains("version conflict")) { + listener.onFailure(new IllegalArgumentException("There was a problem updating the config:[" + id + "]")); + } else { + listener.onFailure(e); + } + } + }); + } + + protected void onCreateMappingsResponse(CreateIndexResponse response, boolean indexingDryRun, ActionListener listener) + throws IOException { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.CONFIG_INDEX); + prepareConfigIndexing(indexingDryRun, listener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.CONFIG_INDEX); + listener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.CONFIG_INDEX + "with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + } + + protected String checkShardsFailure(IndexResponse response) { + StringBuilder failureReasons = new StringBuilder(); + if (response.getShardInfo().getFailed() > 0) { + for (ReplicationResponse.ShardInfo.Failure failure : response.getShardInfo().getFailures()) { + failureReasons.append(failure); + } + return failureReasons.toString(); + } + return null; + } + + /** + * Validate config/syntax, and runtime error of config features + * @param id config id + * @param indexingDryRun if false, then will eventually index detector; true, skip indexing detector + * @throws IOException when fail to parse feature aggregation + */ + // TODO: move this method to util class so that it can be re-usable for more use cases + // 
https://github.com/opensearch-project/anomaly-detection/issues/39 + protected void validateConfigFeatures(String id, boolean indexingDryRun, ActionListener listener) throws IOException { + if (config != null && (config.getFeatureAttributes() == null || config.getFeatureAttributes().isEmpty())) { + checkConfigNameExists(id, indexingDryRun, listener); + return; + } + // checking configuration/syntax error of detector features + String error = RestHandlerUtils.checkFeaturesSyntax(config, maxFeatures); + if (StringUtils.isNotBlank(error)) { + if (indexingDryRun) { + listener.onFailure(createValidationException(error, ValidationIssueType.FEATURE_ATTRIBUTES)); + return; + } + listener.onFailure(new OpenSearchStatusException(error, RestStatus.BAD_REQUEST)); + return; + } + // checking runtime error from feature query + ActionListener>> validateFeatureQueriesListener = ActionListener + .wrap( + response -> { checkConfigNameExists(id, indexingDryRun, listener); }, + exception -> { + listener.onFailure(createValidationException(exception.getMessage(), ValidationIssueType.FEATURE_ATTRIBUTES)); + } + ); + MultiResponsesDelegateActionListener>> multiFeatureQueriesResponseListener = + new MultiResponsesDelegateActionListener>>( + validateFeatureQueriesListener, + config.getFeatureAttributes().size(), + getFeatureErrorMsg(config.getName()), + false + ); + + for (Feature feature : config.getFeatureAttributes()) { + SearchSourceBuilder ssb = new SearchSourceBuilder().size(1).query(QueryBuilders.matchAllQuery()); + AggregatorFactories.Builder internalAgg = parseAggregators( + feature.getAggregation().toString(), + xContentRegistry, + feature.getId() + ); + ssb.aggregation(internalAgg.getAggregatorFactories().iterator().next()); + SearchRequest searchRequest = new SearchRequest().indices(config.getIndices().toArray(new String[0])).source(ssb); + ActionListener searchResponseListener = ActionListener.wrap(response -> { + Optional aggFeatureResult = searchFeatureDao.parseResponse(response, Arrays.asList(feature.getId())); + if (aggFeatureResult.isPresent()) { + multiFeatureQueriesResponseListener + .onResponse( + new MergeableList>(new ArrayList>(Arrays.asList(aggFeatureResult))) + ); + } else { + String errorMessage = CommonMessages.FEATURE_WITH_EMPTY_DATA_MSG + feature.getName(); + logger.error(errorMessage); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + } + }, e -> { + String errorMessage; + if (isExceptionCausedByInvalidQuery(e)) { + errorMessage = CommonMessages.FEATURE_WITH_INVALID_QUERY_MSG + feature.getName(); + } else { + errorMessage = CommonMessages.UNKNOWN_SEARCH_QUERY_EXCEPTION_MSG + feature.getName(); + } + logger.error(errorMessage, e); + multiFeatureQueriesResponseListener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST, e)); + }); + clientUtil.asyncRequestWithInjectedSecurity(searchRequest, client::search, user, client, context, searchResponseListener); + } + } + + protected abstract TimeSeriesException createValidationException(String msg, ValidationIssueType type); + + protected abstract Config parse(XContentParser parser, GetResponse response) throws IOException; + + // have listener as a function parameter instead of instance fields so that we can create new listeners + // encapsulating topmost listener + protected abstract void confirmHistoricalRunning(String detectorId, ActionListener listener); + + protected abstract Integer getMaxSingleStreamConfigs(); + + protected abstract Integer 
getMaxHCConfigs();
+
+    protected abstract String getExceedMaxSingleStreamConfigsErrorMsg(int maxSingleStreamConfigs);
+
+    protected abstract String getExceedMaxHCConfigsErrorMsg(int maxHCConfigs);
+
+    protected abstract String getNoDocsInUserIndexErrorMsg(String suppliedIndices);
+
+    protected abstract String getDuplicateConfigErrorMsg(String name, List<String> otherConfigIds);
+
+    protected abstract String getFeatureErrorMsg(String id);
+
+    protected abstract Config copyConfig(User user, Config config);
+
+    protected abstract T createIndexConfigResponse(IndexResponse indexResponse, Config config);
+
+    protected abstract Set<ValidationAspect> getDefaultValidationType();
+
+    protected abstract void validateModel(ActionListener<T> listener);
+}
diff --git a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java
similarity index 64%
rename from src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java
rename to src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java
index f279f8b63..03d76918f 100644
--- a/src/main/java/org/opensearch/ad/rest/handler/AnomalyDetectorActionHandler.java
+++ b/src/main/java/org/opensearch/timeseries/rest/handler/ConfigUpdateConfirmer.java
@@ -9,7 +9,7 @@
  * GitHub history for details.
  */

-package org.opensearch.ad.rest.handler;
+package org.opensearch.timeseries.rest.handler;

 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@@ -21,7 +21,6 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.core.rest.RestStatus;
@@ -29,72 +28,73 @@
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.function.ExecutorFunction;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.timeseries.util.RestHandlerUtils;

 /**
- * Common handler to process AD request.
+ * Get job to make sure job has been stopped before updating a config.
  */
-public class AnomalyDetectorActionHandler {
+public class ConfigUpdateConfirmer {

-    private final Logger logger = LogManager.getLogger(AnomalyDetectorActionHandler.class);
+    private final Logger logger = LogManager.getLogger(ConfigUpdateConfirmer.class);

     /**
-     * Get detector job for update/delete AD job.
-     * If AD job exist, will return error message; otherwise, execute function.
+     * Get job for update/delete config.
+     * If the job exists, will return an error message; otherwise, execute function.
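A hypothetical call site for the confirmJobRunning method documented here, as used by the update path in AbstractTimeSeriesActionHandler: the update only proceeds when no enabled job document exists for the config. The listener's type parameter is elided in the diff, so the wildcard below is an assumption; ExecutorFunction is the plugin's no-arg functional interface:

import org.opensearch.action.ActionListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.timeseries.rest.handler.ConfigUpdateConfirmer;

public class ConfirmBeforeUpdateExample {
    // doUpdate runs only if the job index is missing, the job document is absent,
    // or the job is disabled; otherwise the listener receives a BAD_REQUEST failure.
    static void updateIfStopped(
        ClusterService clusterService,
        Client client,
        String configId,
        ActionListener<?> listener,
        Runnable doUpdate,
        NamedXContentRegistry xContentRegistry
    ) {
        new ConfigUpdateConfirmer()
            .confirmJobRunning(clusterService, client, configId, listener, doUpdate::run, xContentRegistry);
    }
}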
* - * @param clusterService ES cluster service - * @param client ES node client - * @param detectorId detector identifier + * @param clusterService OS cluster service + * @param client OS node client + * @param id job identifier * @param listener Listener to send response - * @param function AD function + * @param function time series function * @param xContentRegistry Registry which is used for XContentParser */ - public void getDetectorJob( + public void confirmJobRunning( ClusterService clusterService, Client client, - String detectorId, + String id, ActionListener listener, ExecutorFunction function, NamedXContentRegistry xContentRegistry ) { + // forecasting and ad share the same job index if (clusterService.state().metadata().indices().containsKey(CommonName.JOB_INDEX)) { - GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(detectorId); + GetRequest request = new GetRequest(CommonName.JOB_INDEX).id(id); client .get( request, - ActionListener - .wrap(response -> onGetAdJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { - logger.error("Fail to get anomaly detector job: " + detectorId, exception); - listener.onFailure(exception); - }) + ActionListener.wrap(response -> onGetJobResponseForWrite(response, listener, function, xContentRegistry), exception -> { + logger.error("Fail to get job: " + id, exception); + listener.onFailure(exception); + }) ); } else { function.execute(); } } - private void onGetAdJobResponseForWrite( + private void onGetJobResponseForWrite( GetResponse response, ActionListener listener, ExecutorFunction function, NamedXContentRegistry xContentRegistry ) { if (response.isExists()) { - String adJobId = response.getId(); - if (adJobId != null) { - // check if AD job is running on the detector, if yes, we can't delete the detector + String jobId = response.getId(); + if (jobId != null) { + // check if job is running, if yes, we can't delete the config try ( XContentParser parser = RestHandlerUtils .createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetectorJob adJob = AnomalyDetectorJob.parse(parser); + Job adJob = Job.parse(parser); if (adJob.isEnabled()) { - listener.onFailure(new OpenSearchStatusException("Detector job is running: " + adJobId, RestStatus.BAD_REQUEST)); + listener.onFailure(new OpenSearchStatusException("Job is running: " + jobId, RestStatus.BAD_REQUEST)); return; } } catch (IOException e) { - String message = "Failed to parse anomaly detector job " + adJobId; + String message = "Failed to parse job " + jobId; logger.error(message, e); listener.onFailure(new OpenSearchStatusException(message, RestStatus.BAD_REQUEST)); } diff --git a/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java new file mode 100644 index 000000000..fdd8ee014 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/rest/handler/IndexJobActionHandler.java @@ -0,0 +1,634 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.rest.handler; + +import static org.opensearch.action.DocWriteResponse.Result.CREATED; +import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionType; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule; +import org.opensearch.jobscheduler.spi.schedule.Schedule; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.transport.ResultRequest; +import org.opensearch.timeseries.transport.ResultResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.base.Throwables; + +/** + * job REST action handler to process POST/PUT request. 
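The createJob method further down maps the config's IntervalTimeConfiguration onto a job-scheduler IntervalSchedule anchored at enable time; job-scheduler then invokes the job once per interval. A tiny standalone illustration (the 5-minute interval is chosen arbitrarily):

import java.time.Instant;
import java.time.temporal.ChronoUnit;

import org.opensearch.jobscheduler.spi.schedule.IntervalSchedule;

public class ScheduleExample {
    public static void main(String[] args) {
        // A config with a 5-minute interval becomes a 5-minute IntervalSchedule.
        IntervalSchedule schedule = new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES);
        System.out.println(schedule.getInterval() + " " + schedule.getUnit());
    }
}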
+ */ +public abstract class IndexJobActionHandler< + IndexType extends Enum & TimeSeriesIndex, + IndexManagementType extends IndexManagement, + TaskCacheManagerType extends TaskCacheManager, + TaskTypeEnum extends TaskType, + TaskClass extends TimeSeriesTask, + TaskManagerType extends TaskManager, + IndexableResultType extends IndexableResult, + ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder< + IndexType, IndexManagementType, TaskCacheManagerType, TaskTypeEnum, TaskClass, TaskManagerType, IndexableResultType + > + > { + + private final IndexManagementType indexManagement; + private final Client client; + private final NamedXContentRegistry xContentRegistry; + private final TaskManagerType taskManager; + + private final Logger logger = LogManager.getLogger(IndexJobActionHandler.class); + private final TimeValue requestTimeout; + private final ExecuteResultResponseRecorderType recorder; + private final ActionType> resultAction; + private final AnalysisType analysisType; + private final String stateIndex; + private final ActionType stopConfigAction; + private final NodeStateManager nodeStateManager; + + /** + * Constructor function. + * + * @param client ES node client that executes actions on the local node + * @param indexManagement index manager + * @param requestTimeout request time out configuration + * @param xContentRegistry Registry which is used for XContentParser + * @param taskManager task manager + * @param recorder Utility to record AnomalyResultAction execution result + * @param resultAction result action + * @param analysisType analysis type + * @param stateIndex State index name + * @param stopConfigAction Stop config action + * @param nodeStateManager Node state manager + */ + public IndexJobActionHandler( + Client client, + IndexManagementType indexManagement, + NamedXContentRegistry xContentRegistry, + TaskManagerType taskManager, + ExecuteResultResponseRecorderType recorder, + ActionType> resultAction, + AnalysisType analysisType, + String stateIndex, + ActionType stopConfigAction, + NodeStateManager nodeStateManager, + Settings settings, + Setting timeoutSetting + ) { + this.client = client; + this.indexManagement = indexManagement; + this.xContentRegistry = xContentRegistry; + this.taskManager = taskManager; + this.recorder = recorder; + this.resultAction = resultAction; + this.analysisType = analysisType; + this.stateIndex = stateIndex; + this.stopConfigAction = stopConfigAction; + this.nodeStateManager = nodeStateManager; + this.requestTimeout = timeoutSetting.get(settings); + } + + /** + * Start job. + * 1. If job doesn't exist, create new job. + * 2. If job exists: a). if job enabled, return error message; b). if job disabled, enable job. + * @param config config accessor + * @param listener Listener to send responses + */ + public void startJob(Config config, TransportService transportService, ActionListener listener) { + // this start listener is created & injected throughout the job handler so that whenever the job response is received, + // there's the extra step of trying to index results and update detector state with a 60s delay. 
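+ // For example, assuming a config with a 10-minute interval: a job response arriving at 12:10 triggers a single result request covering [12:00, 12:10] before the response is passed back to the caller.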
+ ActionListener startListener = ActionListener.wrap(r -> { + try { + Instant executionEndTime = Instant.now(); + IntervalTimeConfiguration schedule = (IntervalTimeConfiguration) config.getInterval(); + Instant executionStartTime = executionEndTime.minus(schedule.getInterval(), schedule.getUnit()); + ResultRequest getRequest = createResultRequest( + config.getId(), + executionStartTime.toEpochMilli(), + executionEndTime.toEpochMilli() + ); + client + .execute( + resultAction, + getRequest, + ActionListener + .wrap(response -> recorder.indexResult(executionStartTime, executionEndTime, response, config), exception -> { + recorder + .indexResultException( + executionStartTime, + executionEndTime, + Throwables.getStackTraceAsString(exception), + null, + config + ); + }) + ); + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + listener.onResponse(r); + + }, listener::onFailure); + if (!indexManagement.doesJobIndexExist()) { + indexManagement.initJobIndex(ActionListener.wrap(response -> { + if (response.isAcknowledged()) { + logger.info("Created {} with mappings.", CommonName.JOB_INDEX); + createJob(config, transportService, startListener); + } else { + logger.warn("Created {} with mappings call not acknowledged.", CommonName.JOB_INDEX); + startListener + .onFailure( + new OpenSearchStatusException( + "Created " + CommonName.JOB_INDEX + " with mappings call not acknowledged.", + RestStatus.INTERNAL_SERVER_ERROR + ) + ); + } + }, exception -> startListener.onFailure(exception))); + } else { + createJob(config, transportService, startListener); + } + } + + private void createJob(Config config, TransportService transportService, ActionListener listener) { + try { + IntervalTimeConfiguration interval = (IntervalTimeConfiguration) config.getInterval(); + Schedule schedule = new IntervalSchedule(Instant.now(), (int) interval.getInterval(), interval.getUnit()); + Duration duration = Duration.of(interval.getInterval(), interval.getUnit()); + + Job job = new Job( + config.getId(), + schedule, + config.getWindowDelay(), + true, + Instant.now(), + null, + Instant.now(), + duration.getSeconds(), + config.getUser(), + config.getCustomResultIndex(), + analysisType + ); + + getJobForWrite(config, job, transportService, listener); + } catch (Exception e) { + String message = "Failed to parse job " + config.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + + private void getJobForWrite(Config config, Job job, TransportService transportService, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(config.getId()); + + client + .get( + getRequest, + ActionListener + .wrap( + response -> onGetJobForWrite(response, config, job, transportService, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onGetJobForWrite( + GetResponse response, + Config config, + Job job, + TransportService transportService, + ActionListener listener + ) throws IOException { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job currentAdJob = Job.parse(parser); + if (currentAdJob.isEnabled()) { + listener + .onFailure( + new OpenSearchStatusException("Job is already running: " + config.getId(), RestStatus.OK) + ); + return; + } else { + Job newJob =
new Job( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + job.isEnabled(), + Instant.now(), + currentAdJob.getDisabledTime(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex(), + job.getAnalysisType() + ); + // Get latest realtime task and check its state before indexing the job. Will reset a running realtime task + // as STOPPED first if the job is disabled, then start a new job and create a new realtime task. + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener + .wrap( + r -> { indexJob(newJob, null, listener); }, + e -> { + // Have logged error message in ADTaskManager#startDetector + listener.onFailure(e); + } + ) + ); + } + } catch (IOException e) { + String message = "Failed to parse job " + job.getName(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + startConfig( + config, + null, + job.getUser(), + transportService, + ActionListener.wrap(r -> { indexJob(job, null, listener); }, e -> listener.onFailure(e)) + ); + } + } + + /** + * Start config. + * For historical analysis, this method will be called on the coordinating node. + * For realtime analysis, we won't know the job's coordinating node until the job starts, so + * this method will be called on a vanilla node. + * + * Will init the task index if it does not exist and write the new task to the index. If the task index + * exists, will check whether there is a task running. If no task is running, reset old tasks + * as not latest and clean up old tasks that exceed the max old task doc limit. + * Then find the node with the least load and dispatch the task to that node (worker node). + * + * @param config config accessor + * @param detectionDateRange detection date range + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void startConfig( + Config config, + DateRange detectionDateRange, + User user, + TransportService transportService, + ActionListener listener + ) { + try { + if (indexManagement.doesStateIndexExist()) { + // If the state index exists, check if the latest task is running + taskManager.getAndExecuteOnLatestConfigLevelTask(config, detectionDateRange, user, transportService, listener); + } else { + // If the state index doesn't exist, create the index and execute the config.
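+ // Note: index creation can race with another node; the failure handler below treats ResourceAlreadyExistsException as success and still creates the new task.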
+ indexManagement.initStateIndex(ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + logger.info("Created {} with mappings.", stateIndex); + taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, detectionDateRange, user, listener); + } else { + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, stateIndex); + logger.warn(error); + listener.onFailure(new OpenSearchStatusException(error, RestStatus.INTERNAL_SERVER_ERROR)); + } + }, e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceAlreadyExistsException) { + taskManager.updateLatestFlagOfOldTasksAndCreateNewTask(config, detectionDateRange, user, listener); + } else { + logger.error("Failed to init state index", e); + listener.onFailure(e); + } + })); + } + } catch (Exception e) { + logger.error("Failed to start config " + config.getId(), e); + listener.onFailure(e); + } + } + + private void indexJob(Job job, ExecutorFunction function, ActionListener listener) throws IOException { + IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(job.toXContent(XContentFactory.jsonBuilder(), RestHandlerUtils.XCONTENT_WITH_TYPE)) + .timeout(requestTimeout) + .id(job.getName()); + client + .index( + indexRequest, + ActionListener + .wrap( + response -> onIndexJobResponse(response, function, listener), + exception -> listener.onFailure(exception) + ) + ); + } + + private void onIndexJobResponse( + IndexResponse response, + ExecutorFunction function, + ActionListener listener + ) { + if (response == null || (response.getResult() != CREATED && response.getResult() != UPDATED)) { + String errorMsg = ExceptionUtil.getShardsFailure(response); + listener.onFailure(new OpenSearchStatusException(errorMsg, response.status())); + return; + } + if (function != null) { + function.execute(); + } else { + JobResponse jobResponse = new JobResponse(response.getId()); + listener.onResponse(jobResponse); + } + } + + /** + * Stop config job. + * 1. If the job does not exist, return an error message. + * 2. If the job exists: a) if the job is disabled, return an error message; b) if the job is enabled, disable it.
+ * + * @param configId config identifier + * @param listener Listener to send responses + */ + public void stopJob(String configId, TransportService transportService, ActionListener listener) { + GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX).id(configId); + + client.get(getRequest, ActionListener.wrap(response -> { + if (response.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (!job.isEnabled()) { + taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, transportService, listener); + } else { + Job newJob = new Job( + job.getName(), + job.getSchedule(), + job.getWindowDelay(), + false, // disable job + job.getEnabledTime(), + Instant.now(), + Instant.now(), + job.getLockDurationSeconds(), + job.getUser(), + job.getCustomResultIndex(), + job.getAnalysisType() + ); + indexJob( + newJob, + () -> client + .execute( + stopConfigAction, + new StopConfigRequest(configId), + stopConfigListener(configId, transportService, listener) + ), + listener + ); + } + } catch (IOException e) { + String message = "Failed to parse job " + configId; + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } else { + listener.onFailure(new OpenSearchStatusException("Job does not exist: " + configId, RestStatus.BAD_REQUEST)); + } + }, exception -> listener.onFailure(exception))); + } + + private ActionListener stopConfigListener( + String configId, + TransportService transportService, + ActionListener listener + ) { + return new ActionListener() { + @Override + public void onResponse(StopConfigResponse stopDetectorResponse) { + if (stopDetectorResponse.success()) { + logger.info("Model deleted successfully for config {}", configId); + // e.g., StopDetectorTransportAction will send out DeleteModelAction which will clear all realtime cache. + // Pass null transport service to method "stopLatestRealtimeTask" to not re-clear coordinating node cache. + taskManager.stopLatestRealtimeTask(configId, TaskState.STOPPED, null, null, listener); + } else { + logger.error("Failed to delete model for config {}", configId); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. + taskManager + .stopLatestRealtimeTask( + configId, + TaskState.FAILED, + new OpenSearchStatusException("Failed to delete model", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete model for config " + configId, e); + // If failed to clear all realtime cache, will try to re-clear coordinating node cache. + taskManager + .stopLatestRealtimeTask( + configId, + TaskState.FAILED, + new OpenSearchStatusException("Failed to execute stop config action", RestStatus.INTERNAL_SERVER_ERROR), + transportService, + listener + ); + } + }; + } + + /** + * Start config. Will create a scheduled job for realtime analysis, + * and start a task for historical analysis.
+ * + * @param configId config id + * @param dateRange historical analysis date range + * @param user user + * @param transportService transport service + * @param context thread context + * @param listener action listener + */ + public void startConfig( + String configId, + DateRange dateRange, + User user, + TransportService transportService, + ThreadContext.StoredContext context, + ActionListener listener + ) { + // upgrade index mapping + indexManagement.update(); + + nodeStateManager.getConfig(configId, analysisType, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + + // Validate if config is ready to start. Will return null if ready to start. + String errorMessage = validateConfig(config.get()); + if (errorMessage != null) { + listener.onFailure(new OpenSearchStatusException(errorMessage, RestStatus.BAD_REQUEST)); + return; + } + String resultIndex = config.get().getCustomResultIndex(); + if (resultIndex == null) { + startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config); + return; + } + context.restore(); + indexManagement + .initCustomResultIndexAndExecute( + resultIndex, + () -> startRealtimeOrHistoricalAnalysis(dateRange, user, transportService, listener, config), + listener + ); + + }, listener); + } + + private String validateConfig(Config config) { + String error = null; + if (config.getFeatureAttributes().isEmpty()) { + error = "Can't start job as no features configured"; + } else if (config.getEnabledFeatureIds().isEmpty()) { + error = "Can't start job as no enabled features configured"; + } + return error; + } + + private void startRealtimeOrHistoricalAnalysis( + DateRange dateRange, + User user, + TransportService transportService, + ActionListener listener, + Optional config + ) { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + if (dateRange == null) { + // start realtime job + startJob(config.get(), transportService, listener); + } else { + // start historical analysis task + taskManager.startHistorical(config.get(), dateRange, user, transportService, listener); + } + } catch (Exception e) { + logger.error("Failed to stash context", e); + listener.onFailure(e); + } + } + + /** + * Stop config. + * For realtime config, will set job as disabled. + * For historical config, will set its task as cancelled.
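+ * Stopping a realtime config also fires the stop config action so that model state cached on data nodes is released (see stopConfigListener above).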
+ * + * @param configId config id + * @param historical stop historical analysis or not + * @param user user + * @param transportService transport service + * @param listener action listener + */ + public void stopConfig( + String configId, + boolean historical, + User user, + TransportService transportService, + ActionListener listener + ) { + nodeStateManager.getConfig(configId, analysisType, (config) -> { + if (!config.isPresent()) { + listener.onFailure(new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND)); + return; + } + if (historical) { + // stop historical analysis + taskManager + .getAndExecuteOnLatestConfigLevelTask( + configId, + getHistorialConfigTaskTypes(), + (task) -> taskManager.stopHistoricalAnalysis(configId, task, user, listener), + transportService, + false, // no need to reset task state when stopping the config + listener + ); + } else { + // stop realtime job + stopJob(configId, transportService, listener); + } + }, listener); + } + + protected abstract ResultRequest createResultRequest(String configID, long start, long end); + + protected abstract List getHistorialConfigTaskTypes(); +} diff --git a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java index a9aebff53..8ce4cbf9b 100644 --- a/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java +++ b/src/main/java/org/opensearch/timeseries/settings/TimeSeriesSettings.java @@ -21,9 +21,9 @@ public class TimeSeriesSettings { // the larger shingle size, the harder to fill in a complete shingle public static final int MAX_SHINGLE_SIZE = 60; - public static final String CONFIG_INDEX_MAPPING_FILE = "mappings/anomaly-detectors.json"; + public static final String CONFIG_INDEX_MAPPING_FILE = "mappings/config.json"; - public static final String JOBS_INDEX_MAPPING_FILE = "mappings/anomaly-detector-jobs.json"; + public static final String JOBS_INDEX_MAPPING_FILE = "mappings/job.json"; // 100,000 insertions costs roughly 1KB. public static final int DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION = 100_000; @@ -52,6 +52,10 @@ public class TimeSeriesSettings { public static final Duration HOURLY_MAINTENANCE = Duration.ofHours(1); + // Maximum number of deleted tasks that can be kept in the cache.
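+ // The setting is dynamic and bounded to [1, 10000]; for example (an illustrative request, not part of this patch): PUT _cluster/settings { "persistent": { "plugins.timeseries.max_cached_deleted_tasks": 2000 } }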
+ public static final Setting MAX_CACHED_DELETED_TASKS = Setting + .intSetting("plugins.timeseries.max_cached_deleted_tasks", 1000, 1, 10_000, Setting.Property.NodeScope, Setting.Property.Dynamic); + // ====================================== // Checkpoint setting // ====================================== @@ -185,7 +189,12 @@ public class TimeSeriesSettings { ); // ====================================== - // AD Index setting + // Index setting // ====================================== public static int MAX_UPDATE_RETRY_TIMES = 10_000; + + // ====================================== + // JOB + // ====================================== + public static final long DEFAULT_JOB_LOC_DURATION_SECONDS = 60; } diff --git a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java similarity index 95% rename from src/main/java/org/opensearch/ad/stats/InternalStatNames.java rename to src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java index 56ff012a5..356a7828d 100644 --- a/src/main/java/org/opensearch/ad/stats/InternalStatNames.java +++ b/src/main/java/org/opensearch/timeseries/stats/InternalStatNames.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.stats; /** * Enum containing names of all internal stats which will not be returned diff --git a/src/main/java/org/opensearch/timeseries/stats/StatNames.java b/src/main/java/org/opensearch/timeseries/stats/StatNames.java index a72e3f1b0..c63f4aac6 100644 --- a/src/main/java/org/opensearch/timeseries/stats/StatNames.java +++ b/src/main/java/org/opensearch/timeseries/stats/StatNames.java @@ -19,30 +19,46 @@ * AD stats REST API. */ public enum StatNames { - AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count"), - AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count"), - AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count"), - AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count"), - DETECTOR_COUNT("detector_count"), - SINGLE_ENTITY_DETECTOR_COUNT("single_entity_detector_count"), - MULTI_ENTITY_DETECTOR_COUNT("multi_entity_detector_count"), - ANOMALY_DETECTORS_INDEX_STATUS("anomaly_detectors_index_status"), - ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status"), - MODELS_CHECKPOINT_INDEX_STATUS("models_checkpoint_index_status"), - ANOMALY_DETECTION_JOB_INDEX_STATUS("anomaly_detection_job_index_status"), - ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status"), - MODEL_INFORMATION("models"), - AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count"), - AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count"), - AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count"), - AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count"), - MODEL_COUNT("model_count"), - MODEL_CORRUTPION_COUNT("model_corruption_count"); + // common stats + CONFIG_INDEX_STATUS("config_index_status", StatType.TIMESERIES), + JOB_INDEX_STATUS("job_index_status", StatType.TIMESERIES), + // AD stats + AD_EXECUTE_REQUEST_COUNT("ad_execute_request_count", StatType.AD), + AD_EXECUTE_FAIL_COUNT("ad_execute_failure_count", StatType.AD), + AD_HC_EXECUTE_REQUEST_COUNT("ad_hc_execute_request_count", StatType.AD), + AD_HC_EXECUTE_FAIL_COUNT("ad_hc_execute_failure_count", StatType.AD), + DETECTOR_COUNT("detector_count", StatType.AD), + SINGLE_STREAM_DETECTOR_COUNT("single_stream_detector_count", StatType.AD), + HC_DETECTOR_COUNT("hc_detector_count", StatType.AD), + 
ANOMALY_RESULTS_INDEX_STATUS("anomaly_results_index_status", StatType.AD), + AD_MODELS_CHECKPOINT_INDEX_STATUS("anomaly_models_checkpoint_index_status", StatType.AD), + ANOMALY_DETECTION_STATE_STATUS("anomaly_detection_state_status", StatType.AD), + MODEL_INFORMATION("models", StatType.AD), + AD_EXECUTING_BATCH_TASK_COUNT("ad_executing_batch_task_count", StatType.AD), + AD_CANCELED_BATCH_TASK_COUNT("ad_canceled_batch_task_count", StatType.AD), + AD_TOTAL_BATCH_TASK_EXECUTION_COUNT("ad_total_batch_task_execution_count", StatType.AD), + AD_BATCH_TASK_FAILURE_COUNT("ad_batch_task_failure_count", StatType.AD), + MODEL_COUNT("model_count", StatType.AD), + AD_MODEL_CORRUTPION_COUNT("ad_model_corruption_count", StatType.AD), + // forecast stats + FORECAST_EXECUTE_REQUEST_COUNT("forecast_execute_request_count", StatType.FORECAST), + FORECAST_EXECUTE_FAIL_COUNT("forecast_execute_failure_count", StatType.FORECAST), + FORECAST_HC_EXECUTE_REQUEST_COUNT("forecast_hc_execute_request_count", StatType.FORECAST), + FORECAST_HC_EXECUTE_FAIL_COUNT("forecast_hc_execute_failure_count", StatType.FORECAST), + FORECAST_RESULTS_INDEX_STATUS("forecast_results_index_status", StatType.FORECAST), + FORECAST_MODELS_CHECKPOINT_INDEX_STATUS("forecast_models_checkpoint_index_status", StatType.FORECAST), + FORECAST_STATE_STATUS("forecast_state_status", StatType.FORECAST), + FORECASTER_COUNT("forecaster_count", StatType.FORECAST), + SINGLE_STREAM_FORECASTER_COUNT("single_stream_forecaster_count", StatType.FORECAST), + HC_FORECASTER_COUNT("hc_forecaster_count", StatType.FORECAST), + FORECAST_MODEL_CORRUTPION_COUNT("forecast_model_corruption_count", StatType.FORECAST); - private String name; + private final String name; + private final StatType type; - StatNames(String name) { + StatNames(String name, StatType type) { this.name = name; + this.type = type; } /** @@ -54,6 +70,10 @@ public String getName() { return name; } + public StatType getType() { + return type; + } + /** * Get set of stat names * diff --git a/src/main/java/org/opensearch/timeseries/stats/StatType.java b/src/main/java/org/opensearch/timeseries/stats/StatType.java new file mode 100644 index 000000000..cca482bc7 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/stats/StatType.java @@ -0,0 +1,18 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.stats; + +public enum StatType { + AD, + FORECAST, + TIMESERIES +} diff --git a/src/main/java/org/opensearch/ad/stats/ADStats.java b/src/main/java/org/opensearch/timeseries/stats/Stats.java similarity index 67% rename from src/main/java/org/opensearch/ad/stats/ADStats.java rename to src/main/java/org/opensearch/timeseries/stats/Stats.java index 1fb0e8fe4..f9b168392 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStats.java +++ b/src/main/java/org/opensearch/timeseries/stats/Stats.java @@ -9,24 +9,20 @@ * GitHub history for details. */ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.stats; import java.util.HashMap; import java.util.Map; -/** - * This class is the main entry-point for access to the stats that the AD plugin keeps track of.
- */ -public class ADStats { - - private Map> stats; +public class Stats { + private Map> stats; /** * Constructor * * @param stats Map of the stats that are to be kept */ - public ADStats(Map> stats) { + public Stats(Map> stats) { this.stats = stats; } @@ -35,7 +31,7 @@ public ADStats(Map> stats) { * * @return all of the stats */ - public Map> getStats() { + public Map> getStats() { return stats; } @@ -43,10 +39,10 @@ public Map> getStats() { * Get individual stat by stat name * * @param key Name of stat - * @return ADStat + * @return TimeSeriesStat * @throws IllegalArgumentException thrown on illegal statName */ - public ADStat getStat(String key) throws IllegalArgumentException { + public TimeSeriesStat getStat(String key) throws IllegalArgumentException { if (!stats.keySet().contains(key)) { throw new IllegalArgumentException("Stat=\"" + key + "\" does not exist"); } @@ -58,7 +54,7 @@ public ADStat getStat(String key) throws IllegalArgumentException { * * @return Map of stats kept at the node level */ - public Map> getNodeStats() { + public Map> getNodeStats() { return getClusterOrNodeStats(false); } @@ -67,14 +63,14 @@ public Map> getNodeStats() { * * @return Map of stats kept at the cluster level */ - public Map> getClusterStats() { + public Map> getClusterStats() { return getClusterOrNodeStats(true); } - private Map> getClusterOrNodeStats(Boolean getClusterStats) { - Map> statsMap = new HashMap<>(); + private Map> getClusterOrNodeStats(Boolean getClusterStats) { + Map> statsMap = new HashMap<>(); - for (Map.Entry> entry : stats.entrySet()) { + for (Map.Entry> entry : stats.entrySet()) { if (entry.getValue().isClusterLevel() == getClusterStats) { statsMap.put(entry.getKey(), entry.getValue()); } diff --git a/src/main/java/org/opensearch/ad/stats/ADStat.java b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java similarity index 86% rename from src/main/java/org/opensearch/ad/stats/ADStat.java rename to src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java index 531205907..e10ab9127 100644 --- a/src/main/java/org/opensearch/ad/stats/ADStat.java +++ b/src/main/java/org/opensearch/timeseries/stats/TimeSeriesStat.java @@ -9,17 +9,17 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.stats; +package org.opensearch.timeseries.stats; import java.util.function.Supplier; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; /** * Class represents a stat the plugin keeps track of */ -public class ADStat { +public class TimeSeriesStat { private Boolean clusterLevel; private Supplier supplier; @@ -29,7 +29,7 @@ public class ADStat { * @param clusterLevel whether the stat has clusterLevel scope or nodeLevel scope * @param supplier supplier that returns the stat's value */ - public ADStat(Boolean clusterLevel, Supplier supplier) { + public TimeSeriesStat(Boolean clusterLevel, Supplier supplier) { this.clusterLevel = clusterLevel; this.supplier = supplier; } diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java similarity index 95% rename from src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java index 39acd94ff..0953e9450 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/CounterSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/CounterSupplier.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.concurrent.atomic.LongAdder; import java.util.function.Supplier; diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java similarity index 92% rename from src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java index ab9177cb5..1da433108 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/IndexStatusSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/IndexStatusSupplier.java @@ -9,11 +9,11 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.function.Supplier; -import org.opensearch.ad.util.IndexUtils; +import org.opensearch.timeseries.util.IndexUtils; /** * IndexStatusSupplier provides the status of an index as the value diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeCountSupplier.java similarity index 50% rename from src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeCountSupplier.java index 8fdac74d7..f01305537 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeCountSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeCountSupplier.java @@ -9,20 +9,22 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.function.Supplier; import java.util.stream.Stream; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.forecast.caching.ForecastCacheProvider; /** * ModelsOnNodeCountSupplier provides the number of models a node contains */ public class ModelsOnNodeCountSupplier implements Supplier { - private ModelManager modelManager; - private CacheProvider cache; + private ADModelManager modelManager; + private ADCacheProvider adCache; + private ForecastCacheProvider forecastCache; /** * Constructor @@ -30,13 +32,19 @@ public class ModelsOnNodeCountSupplier implements Supplier { * @param modelManager object that manages the model partitions hosted on the node * @param cache object that manages multi-entity detectors' models */ - public ModelsOnNodeCountSupplier(ModelManager modelManager, CacheProvider cache) { + public ModelsOnNodeCountSupplier(ADModelManager modelManager, ADCacheProvider adCache, ForecastCacheProvider forecastCache) { this.modelManager = modelManager; - this.cache = cache; + this.adCache = adCache; + this.forecastCache = forecastCache; } @Override public Long get() { - return Stream.concat(modelManager.getAllModels().stream(), cache.get().getAllModels().stream()).count(); + return Stream + .concat( + Stream.concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()), + forecastCache.get().getAllModels().stream() + ) + .count(); } } diff --git a/src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeSupplier.java new file mode 100644 index 000000000..9833a6433 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/ModelsOnNodeSupplier.java @@ -0,0 +1,125 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.stats.suppliers; + +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.forecast.settings.ForecastSettings.FORECAST_MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.timeseries.ml.ModelState.LAST_CHECKPOINT_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.LAST_USED_TIME_KEY; +import static org.opensearch.timeseries.ml.ModelState.MODEL_TYPE_KEY; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.opensearch.ad.caching.ADCacheProvider; +import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.caching.ForecastCacheProvider; +import org.opensearch.forecast.constant.ForecastCommonName; +import org.opensearch.timeseries.constant.CommonName; + +/** + * ModelsOnNodeSupplier provides a List of ModelStates info for the models the node contains + */ +public class ModelsOnNodeSupplier implements Supplier>> { + private ADModelManager modelManager; + private ADCacheProvider adCache; + private ForecastCacheProvider forecastCache; + + // the max number of models to return per node. Defaults to 100. + private volatile int adNumModelsToReturn; + private volatile int forecastNumModelsToReturn; + + /** + * Set that contains the model stats that should be exposed. + */ + public static Set MODEL_STATE_STAT_KEYS = new HashSet<>( + Arrays + .asList( + CommonName.MODEL_ID_FIELD, + ADCommonName.DETECTOR_ID_KEY, + MODEL_TYPE_KEY, + CommonName.ENTITY_KEY, + LAST_USED_TIME_KEY, + LAST_CHECKPOINT_TIME_KEY, + ForecastCommonName.FORECASTER_ID_KEY + ) + ); + + /** + * Constructor + * + * @param modelManager object that manages the model partitions hosted on the node + * @param adCache object that manages AD models on the node + * @param forecastCache object that manages forecast models on the node + * @param settings node settings accessor + * @param clusterService Cluster service accessor + */ + public ModelsOnNodeSupplier( + ADModelManager modelManager, + ADCacheProvider adCache, + ForecastCacheProvider forecastCache, + Settings settings, + ClusterService clusterService + ) { + this.modelManager = modelManager; + this.adCache = adCache; + this.forecastCache = forecastCache; + this.adNumModelsToReturn = AD_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_MAX_MODEL_SIZE_PER_NODE, it -> this.adNumModelsToReturn = it); + this.forecastNumModelsToReturn = FORECAST_MAX_MODEL_SIZE_PER_NODE.get(settings); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(FORECAST_MAX_MODEL_SIZE_PER_NODE, it -> this.forecastNumModelsToReturn = it); + } + + @Override + public List> get() { + // AD models only here; forecast models are appended via forecastStream below, so they are neither double counted nor capped by the AD limit + Stream> adStream = Stream + .concat(modelManager.getAllModels().stream(), adCache.get().getAllModels().stream()) + .limit(adNumModelsToReturn) + .map( + modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + + Stream> forecastStream = forecastCache + .get() + .getAllModels() + .stream() + .limit(forecastNumModelsToReturn) + .map( +
modelState -> modelState + .getModelStateAsMap() + .entrySet() + .stream() + .filter(entry -> MODEL_STATE_STAT_KEYS.contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + + return Stream.concat(adStream, forecastStream).collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java similarity index 94% rename from src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java rename to src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java index b39ecdde5..e5e60c6ba 100644 --- a/src/main/java/org/opensearch/ad/stats/suppliers/SettableSupplier.java +++ b/src/main/java/org/opensearch/timeseries/stats/suppliers/SettableSupplier.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.stats.suppliers; +package org.opensearch.timeseries.stats.suppliers; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; diff --git a/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java similarity index 79% rename from src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java rename to src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java index bf8cbb860..25f669b41 100644 --- a/src/main/java/org/opensearch/ad/task/ADRealtimeTaskCache.java +++ b/src/main/java/org/opensearch/timeseries/task/RealtimeTaskCache.java @@ -9,19 +9,19 @@ * GitHub history for details. */ -package org.opensearch.ad.task; +package org.opensearch.timeseries.task; import java.time.Instant; /** - * AD realtime task cache which will hold these data + * Realtime task cache which will hold these data: * 1. task state * 2. init progress * 3. error * 4. last job run time - * 5. detector interval + * 5. interval */ -public class ADRealtimeTaskCache { +public class RealtimeTaskCache { // task state private String state; @@ -35,19 +35,19 @@ public class ADRealtimeTaskCache { // track last job run time, will clean up cache if no access after 2 intervals private long lastJobRunTime; - // detector interval in milliseconds. - private long detectorIntervalInMillis; + // interval in milliseconds. + private long intervalInMillis; // we query result index to check if there are any result generated for detector to tell whether it passed initialization of not. // To avoid repeated query when there is no data, record whether we have done that or not.
private boolean queriedResultIndex; - public ADRealtimeTaskCache(String state, Float initProgress, String error, long detectorIntervalInMillis) { + public RealtimeTaskCache(String state, Float initProgress, String error, long detectorIntervalInMillis) { this.state = state; this.initProgress = initProgress; this.error = error; this.lastJobRunTime = Instant.now().toEpochMilli(); - this.detectorIntervalInMillis = detectorIntervalInMillis; + this.intervalInMillis = detectorIntervalInMillis; this.queriedResultIndex = false; } @@ -88,6 +88,6 @@ public void setQueriedResultIndex(boolean queriedResultIndex) { } public boolean expired() { - return lastJobRunTime + 2 * detectorIntervalInMillis < Instant.now().toEpochMilli(); + return lastJobRunTime + 2 * intervalInMillis < Instant.now().toEpochMilli(); } } diff --git a/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java new file mode 100644 index 000000000..d7b363112 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/task/TaskCacheManager.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.task; + +import static org.opensearch.timeseries.settings.TimeSeriesSettings.MAX_CACHED_DELETED_TASKS; + +import java.time.Instant; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.ad.model.AnomalyDetector; +import org.opensearch.ad.task.ADTaskManager; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.forecast.task.ForecastTaskManager; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.transport.TransportService; + +public class TaskCacheManager { + private final Logger logger = LogManager.getLogger(TaskCacheManager.class); + /** + * This field is to cache all realtime tasks on coordinating node. + *
<p>Node: coordinating node</p> + * <p>Key is config id</p>
+ */ + private Map realtimeTaskCaches; + + /** + * This field is to cache all deleted config level tasks on coordinating node. + * Will try to clean up child task and result later. + *
<p>Node: coordinating node</p>
+ * Check {@link ForecastTaskManager#cleanChildTasksAndResultsOfDeletedTask()} + */ + private Queue deletedTasks; + + protected volatile Integer maxCachedDeletedTask; + + public TaskCacheManager(Settings settings, ClusterService clusterService) { + this.realtimeTaskCaches = new ConcurrentHashMap<>(); + this.deletedTasks = new ConcurrentLinkedQueue<>(); + this.maxCachedDeletedTask = MAX_CACHED_DELETED_TASKS.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_CACHED_DELETED_TASKS, it -> maxCachedDeletedTask = it); + } + + public RealtimeTaskCache getRealtimeTaskCache(String configId) { + return realtimeTaskCaches.get(configId); + } + + public void initRealtimeTaskCache(String configId, long configIntervalInMillis) { + realtimeTaskCaches.put(configId, new RealtimeTaskCache(null, null, null, configIntervalInMillis)); + logger.debug("Realtime task cache initialized"); + } + + /** + * Add a deleted task's id to the deleted tasks queue. + * @param taskId task id + */ + public void addDeletedTask(String taskId) { + if (deletedTasks.size() < maxCachedDeletedTask) { + deletedTasks.add(taskId); + } + } + + /** + * Check if the deleted task queue has items. + * @return true if there is a deleted task in the cache + */ + public boolean hasDeletedTask() { + return !deletedTasks.isEmpty(); + } + + /** + * Poll one deleted task. + * @return task id + */ + public String pollDeletedTask() { + return this.deletedTasks.poll(); + } + + /** + * Clear realtime task cache. + */ + public void clearRealtimeTaskCache() { + realtimeTaskCaches.clear(); + } + + /** + * Check whether a realtime task field value change is needed by comparing with the cache. + * 1. If a new field value is null, that field is treated as unchanged. + * 2. For state and init progress, a change is needed only if init progress increases (or the old + * init progress is null), or if the state differs and is not moving from RUNNING back to INIT; + * this keeps init progress and state from going backwards. For other fields, any changed value + * counts as a change. + * 3. If no realtime task cache entry is found, a change is considered needed. + * + * @param detectorId detector id + * @param newState new task state + * @param newInitProgress new init progress + * @param newError new error + * @return true if realtime task change needed.
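+ * For example, true is not returned when a cached RUNNING state meets an incoming INIT state, so neither state nor init progress moves backwards.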
+ */ + public boolean isRealtimeTaskChangeNeeded(String detectorId, String newState, Float newInitProgress, String newError) { + if (realtimeTaskCaches.containsKey(detectorId)) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); + boolean stateChangeNeeded = false; + String oldState = realtimeTaskCache.getState(); + if (newState != null + && !newState.equals(oldState) + && !(TaskState.INIT.name().equals(newState) && TaskState.RUNNING.name().equals(oldState))) { + stateChangeNeeded = true; + } + boolean initProgressChangeNeeded = false; + Float existingProgress = realtimeTaskCache.getInitProgress(); + if (newInitProgress != null + && !newInitProgress.equals(existingProgress) + && (existingProgress == null || newInitProgress > existingProgress)) { + initProgressChangeNeeded = true; + } + boolean errorChanged = false; + if (newError != null && !newError.equals(realtimeTaskCache.getError())) { + errorChanged = true; + } + return stateChangeNeeded || initProgressChangeNeeded || errorChanged; + } else { + return true; + } + } + + /** + * Update realtime task cache with new field values. If the realtime task cache exists, update it + * directly if the task is not done; if the task is done, remove the detector's realtime task cache. + * + * If the realtime task cache doesn't exist, will do nothing. The next realtime job run will re-init + * the realtime task cache when it finds the task cache not initialized yet. + * Check {@link ADTaskManager#initCacheWithCleanupIfRequired(String, AnomalyDetector, TransportService, ActionListener)}, + * {@link ADTaskManager#updateLatestRealtimeTaskOnCoordinatingNode(String, String, Long, Long, String, ActionListener)} + * + * @param detectorId detector id + * @param newState new task state + * @param newInitProgress new init progress + * @param newError new error + */ + public void updateRealtimeTaskCache(String detectorId, String newState, Float newInitProgress, String newError) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(detectorId); + if (realtimeTaskCache != null) { + if (newState != null) { + realtimeTaskCache.setState(newState); + } + if (newInitProgress != null) { + realtimeTaskCache.setInitProgress(newInitProgress); + } + if (newError != null) { + realtimeTaskCache.setError(newError); + } + if (newState != null && !TaskState.NOT_ENDED_STATES.contains(newState)) { + // If task is done, will remove its realtime task cache. + logger.info("Realtime task done with state {}, remove RT task cache for detector {}", newState, detectorId); + removeRealtimeTaskCache(detectorId); + } + } else { + logger.debug("Realtime task cache is not initialized yet for detector {}", detectorId); + } + } + + public void refreshRealtimeJobRunTime(String detectorId) { + RealtimeTaskCache taskCache = realtimeTaskCaches.get(detectorId); + if (taskCache != null) { + taskCache.setLastJobRunTime(Instant.now().toEpochMilli()); + } + } + + /** + * Get detector IDs from realtime task cache. + * @return array of detector id + */ + public String[] getDetectorIdsInRealtimeTaskCache() { + return realtimeTaskCaches.keySet().toArray(new String[0]); + } + + /** + * Remove detector's realtime task from cache.
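+ * Called when a realtime task reaches an end state, or by maintenance when the cache entry has seen no job run for two intervals.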
+ * @param detectorId detector id + */ + public void removeRealtimeTaskCache(String detectorId) { + if (realtimeTaskCaches.containsKey(detectorId)) { + logger.info("Delete realtime cache for detector {}", detectorId); + realtimeTaskCaches.remove(detectorId); + } + } + + /** + * We query the result index to check if there are any results generated for a detector, to tell whether it passed initialization or not. + * To avoid repeated queries when there is no data, record whether we have done that or not. + * @param id detector id + */ + public void markResultIndexQueried(String id) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + // we initialize a real time cache at the beginning of AnomalyResultTransportAction if it + // cannot be found. If the cache is empty, we will return early and wait for it to be + // initialized. + if (realtimeTaskCache != null) { + realtimeTaskCache.setQueriedResultIndex(true); + } + } + + /** + * We query the result index to check if there are any results generated for a detector, to tell whether it passed initialization or not. + * + * @param id detector id + * @return whether we have queried the result index or not. + */ + public boolean hasQueriedResultIndex(String id) { + RealtimeTaskCache realtimeTaskCache = realtimeTaskCaches.get(id); + if (realtimeTaskCache != null) { + return realtimeTaskCache.hasQueriedResultIndex(); + } + return false; + } +} diff --git a/src/main/java/org/opensearch/timeseries/task/TaskManager.java b/src/main/java/org/opensearch/timeseries/task/TaskManager.java new file mode 100644 index 000000000..e1e8342f6 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/task/TaskManager.java @@ -0,0 +1,1032 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.task; + +import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES; +import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.timeseries.constant.CommonMessages.CONFIG_IS_RUNNING; +import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES; +import static org.opensearch.timeseries.model.TaskType.taskTypeToString; +import static org.opensearch.timeseries.util.RestHandlerUtils.XCONTENT_WITH_TYPE; +import static org.opensearch.timeseries.util.RestHandlerUtils.createXContentParserFromRegistry; + +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.search.join.ScoreMode; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.NestedQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.index.reindex.DeleteByQueryAction; +import org.opensearch.index.reindex.DeleteByQueryRequest; +import org.opensearch.index.reindex.UpdateByQueryAction; +import org.opensearch.index.reindex.UpdateByQueryRequest; +import org.opensearch.script.Script; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.SortOrder; +import 
org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.common.exception.DuplicateTaskException; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TaskCancelledException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.function.BiCheckedFunction; +import org.opensearch.timeseries.function.ExecutorFunction; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import org.opensearch.timeseries.transport.JobResponse; +import org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +import com.google.common.collect.ImmutableMap; + +public abstract class TaskManager & TimeSeriesIndex, IndexManagementType extends IndexManagement> { + protected static int DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS = 5; + + private final Logger logger = LogManager.getLogger(TaskManager.class); + + protected final TaskCacheManagerType taskCacheManager; + protected final ClusterService clusterService; + protected final Client client; + protected final String stateIndex; + private final List realTimeTaskTypes; + protected final IndexManagementType indexManagement; + protected final NodeStateManager nodeStateManager; + protected final AnalysisType analysisType; + protected final NamedXContentRegistry xContentRegistry; + protected final String configIdFieldName; + + protected volatile Integer maxOldAdTaskDocsPerConfig; + + protected final ThreadPool threadPool; + private final String allResultIndexPattern; + private final String batchTaskThreadPoolName; + + public TaskManager( + TaskCacheManagerType taskCacheManager, + ClusterService clusterService, + Client client, + String stateIndex, + List realTimeTaskTypes, + IndexManagementType indexManagement, + NodeStateManager nodeStateManager, + AnalysisType analysisType, + NamedXContentRegistry xContentRegistry, + String configIdFieldName, + Setting maxOldADTaskDocsPerConfig, + Settings settings, + ThreadPool threadPool, + String allResultIndexPattern, + String batchTaskThreadPoolName + ) { + this.taskCacheManager = taskCacheManager; + this.clusterService = clusterService; + this.client = client; + this.stateIndex = stateIndex; + this.realTimeTaskTypes = realTimeTaskTypes; + this.indexManagement = indexManagement; + this.nodeStateManager = nodeStateManager; + this.analysisType = analysisType; + this.xContentRegistry = xContentRegistry; + this.configIdFieldName = configIdFieldName; + + this.maxOldAdTaskDocsPerConfig = maxOldADTaskDocsPerConfig.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxOldADTaskDocsPerConfig, it -> maxOldAdTaskDocsPerConfig = it); + + this.threadPool = threadPool; + this.allResultIndexPattern = allResultIndexPattern; + 
this.batchTaskThreadPoolName = batchTaskThreadPoolName; + } + + public boolean skipUpdateRealtimeTask(String configId, String error) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(configId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() == 1.0 + && Objects.equals(error, realtimeTaskCache.getError()); + } + + public boolean isHCRealtimeTaskStartInitializing(String detectorId) { + RealtimeTaskCache realtimeTaskCache = taskCacheManager.getRealtimeTaskCache(detectorId); + return realtimeTaskCache != null + && realtimeTaskCache.getInitProgress() != null + && realtimeTaskCache.getInitProgress().floatValue() > 0; + } + + /** + * Maintain running realtime tasks. Check if realtime task cache expires or not. Remove realtime + * task cache directly if expired. + */ + public void maintainRunningRealtimeTasks() { + String[] detectorIds = taskCacheManager.getDetectorIdsInRealtimeTaskCache(); + if (detectorIds == null || detectorIds.length == 0) { + return; + } + for (int i = 0; i < detectorIds.length; i++) { + String detectorId = detectorIds[i]; + RealtimeTaskCache taskCache = taskCacheManager.getRealtimeTaskCache(detectorId); + if (taskCache != null && taskCache.expired()) { + taskCacheManager.removeRealtimeTaskCache(detectorId); + } + } + } + + public void refreshRealtimeJobRunTime(String detectorId) { + taskCacheManager.refreshRealtimeJobRunTime(detectorId); + } + + public void removeRealtimeTaskCache(String detectorId) { + taskCacheManager.removeRealtimeTaskCache(detectorId); + } + + /** + * Update realtime task cache on realtime config's coordinating node. + * + * @param configId config id + * @param state new state + * @param rcfTotalUpdates rcf total updates + * @param intervalInMinutes config interval in minutes + * @param error error + * @param listener action listener + */ + public void updateLatestRealtimeTaskOnCoordinatingNode( + String configId, + String state, + Long rcfTotalUpdates, + Long intervalInMinutes, + String error, + ActionListener listener + ) { + Float initProgress = null; + String newState = null; + // calculate init progress and task state with RCF total updates + if (intervalInMinutes != null && rcfTotalUpdates != null) { + newState = TaskState.INIT.name(); + if (rcfTotalUpdates < TimeSeriesSettings.NUM_MIN_SAMPLES) { + initProgress = (float) rcfTotalUpdates / TimeSeriesSettings.NUM_MIN_SAMPLES; + } else { + newState = TaskState.RUNNING.name(); + initProgress = 1.0f; + } + } + // Check if new state is not null and override state calculated with rcf total updates + if (state != null) { + newState = state; + } + + error = Optional.ofNullable(error).orElse(""); + if (!taskCacheManager.isRealtimeTaskChangeNeeded(configId, newState, initProgress, error)) { + // If task not changed, no need to update, just return + listener.onResponse(null); + return; + } + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.COORDINATING_NODE_FIELD, clusterService.localNode().getId()); + if (initProgress != null) { + updatedFields.put(TimeSeriesTask.INIT_PROGRESS_FIELD, initProgress); + updatedFields + .put( + TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD, + Math.max(0, TimeSeriesSettings.NUM_MIN_SAMPLES - rcfTotalUpdates) * intervalInMinutes + ); + } + if (newState != null) { + updatedFields.put(TimeSeriesTask.STATE_FIELD, newState); + } + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error); + } + Float finalInitProgress = 
initProgress; + // Variable used in lambda expression should be final or effectively final + String finalError = error; + String finalNewState = newState; + updateLatestTask(configId, realTimeTaskTypes, updatedFields, ActionListener.wrap(r -> { + logger.debug("Updated latest realtime AD task successfully for detector {}", configId); + taskCacheManager.updateRealtimeTaskCache(configId, finalNewState, finalInitProgress, finalError); + listener.onResponse(r); + }, e -> { + logger.error("Failed to update realtime task for detector " + configId, e); + listener.onFailure(e); + })); + } + + /** + * Update latest task of a config. + * + * @param configId config id + * @param taskTypes task types + * @param updatedFields updated fields, key: field name, value: new value + * @param listener action listener + */ + public void updateLatestTask( + String configId, + List<TaskTypeEnum> taskTypes, + Map<String, Object> updatedFields, + ActionListener<UpdateResponse> listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, taskTypes, (task) -> { + if (task.isPresent()) { + updateTask(task.get().getTaskId(), updatedFields, listener); + } else { + listener.onFailure(new ResourceNotFoundException(configId, CommonMessages.CAN_NOT_FIND_LATEST_TASK)); + } + }, null, false, listener); + } + + public void getAndExecuteOnLatestConfigLevelTask( + Config config, + DateRange dateRange, + User user, + TransportService transportService, + ActionListener<JobResponse> listener + ) { + getAndExecuteOnLatestConfigLevelTask(config.getId(), getTaskTypes(dateRange), (adTask) -> { + if (!adTask.isPresent() || adTask.get().isDone()) { + updateLatestFlagOfOldTasksAndCreateNewTask(config, dateRange, user, listener); + } else { + listener.onFailure(new OpenSearchStatusException(CONFIG_IS_RUNNING, RestStatus.BAD_REQUEST)); + } + }, transportService, true, listener); + } + + public void updateLatestFlagOfOldTasksAndCreateNewTask( + Config config, + DateRange dateRange, + User user, + ActionListener<JobResponse> listener + ) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, config.getId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + // make sure we reset all latest tasks as false when the user switches from single entity to HC, and vice versa. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(getTaskTypes(dateRange, true)))); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s=%s;", TimeSeriesTask.IS_LATEST_FIELD, false); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List<BulkItemResponse.Failure> bulkFailures = r.getBulkFailures(); + if (bulkFailures.isEmpty()) { + // Realtime AD coordinating node is chosen by job scheduler, we won't know it until realtime AD job + // runs. Just set realtime AD coordinating node as null here, and AD job runner will reset correct + // coordinating node once realtime job starts. + // For historical analysis, this method will be called on coordinating node, so we can set coordinating + // node as local node. + String coordinatingNode = dateRange == null ? null : clusterService.localNode().getId(); + createNewTask(config, dateRange, user, coordinatingNode, listener); + } else { + logger.error("Failed to update old task's state for detector: {}, response: {} ", config.getId(), r.toString()); + listener.onFailure(bulkFailures.get(0).getCause()); + } + }, e -> { + logger.error("Failed to reset old tasks as not latest for detector " + config.getId(), e); + listener.onFailure(e); + })); + } + + /** + * Get latest task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param <T> action listener response type + */ + public <T> void getAndExecuteOnLatestConfigLevelTask( + String configId, + List<TaskTypeEnum> taskTypes, + Consumer<Optional<TaskClass>> function, + TransportService transportService, + boolean resetTaskState, + ActionListener<T> listener + ) { + getAndExecuteOnLatestADTask(configId, null, null, taskTypes, function, transportService, resetTaskState, listener); + } + + /** + * Get one latest task and execute consumer function. + * [Important!] Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param listener action listener + * @param <T> action listener response type + */ + public <T> void getAndExecuteOnLatestADTask( + String configId, + String parentTaskId, + Entity entity, + List<TaskTypeEnum> taskTypes, + Consumer<Optional<TaskClass>> function, + TransportService transportService, + boolean resetTaskState, + ActionListener<T> listener + ) { + getAndExecuteOnLatestTasks(configId, parentTaskId, entity, taskTypes, (taskList) -> { + if (taskList != null && taskList.size() > 0) { + function.accept(Optional.ofNullable(taskList.get(0))); + } else { + function.accept(Optional.empty()); + } + }, transportService, resetTaskState, 1, listener); + } + + public List<TaskTypeEnum> getTaskTypes(DateRange dateRange) { + return getTaskTypes(dateRange, false); + } + + /** + * Update latest realtime task.
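+ * If the latest realtime task is still running, its state is updated to the given state and the config cache on the
+ * coordinating node is cleaned up first; if the task is already done, the listener fails with "job is already stopped".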
+ * + * @param configId config id + * @param state task state + * @param error error + * @param transportService transport service + * @param listener action listener + */ + public void stopLatestRealtimeTask( + String configId, + TaskState state, + Exception error, + TransportService transportService, + ActionListener listener + ) { + getAndExecuteOnLatestConfigLevelTask(configId, getRealTimeTaskTypes(), (adTask) -> { + if (adTask.isPresent() && !adTask.get().isDone()) { + Map updatedFields = new HashMap<>(); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state.name()); + if (error != null) { + updatedFields.put(TimeSeriesTask.ERROR_FIELD, error.getMessage()); + } + ExecutorFunction function = () -> updateTask(adTask.get().getTaskId(), updatedFields, ActionListener.wrap(r -> { + if (error == null) { + listener.onResponse(new JobResponse(configId)); + } else { + listener.onFailure(error); + } + }, e -> { listener.onFailure(e); })); + + String coordinatingNode = adTask.get().getCoordinatingNode(); + if (coordinatingNode != null && transportService != null) { + cleanConfigCache(adTask.get(), transportService, function, listener); + } else { + function.execute(); + } + } else { + listener.onFailure(new OpenSearchStatusException("job is already stopped: " + configId, RestStatus.OK)); + } + }, null, false, listener); + } + + protected void resetTaskStateAsStopped( + TimeSeriesTask task, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + cleanConfigCache(task, transportService, () -> { + String taskId = task.getTaskId(); + Map updatedFields = ImmutableMap.of(TimeSeriesTask.STATE_FIELD, TaskState.STOPPED.name()); + updateTask(taskId, updatedFields, ActionListener.wrap(r -> { + task.setState(TaskState.STOPPED.name()); + if (function != null) { + function.execute(); + } + // For realtime anomaly detection, we only create config level task, no entity level realtime task. + if (isHistoricalHCTask(task)) { + // Reset running entity tasks as STOPPED + resetEntityTasksAsStopped(taskId); + } + }, e -> { + logger.error("Failed to update task state as STOPPED for task " + taskId, e); + listener.onFailure(e); + })); + }, listener); + } + + /** + * the function initializes the cache and only performs cleanup if it is deemed necessary. + * @param id config id + * @param config config accessor + * @param transportService Transport service + * @param listener listener to return back init success or not + */ + public abstract void initCacheWithCleanupIfRequired( + String id, + Config config, + TransportService transportService, + ActionListener listener + ); + + /** + * Get latest config tasks and execute consumer function. + * [Important!] 
Make sure listener returns in function + * + * @param configId config id + * @param parentTaskId parent task id + * @param entity entity value + * @param taskTypes task types + * @param function consumer function + * @param transportService transport service + * @param resetTaskState reset task state or not + * @param size return how many tasks + * @param listener action listener + * @param <T> response type of action listener + */ + public <T> void getAndExecuteOnLatestTasks( + String configId, + String parentTaskId, + Entity entity, + List<TaskTypeEnum> taskTypes, + Consumer<List<TaskClass>> function, + TransportService transportService, + boolean resetTaskState, + int size, + ActionListener<T> listener + ) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, configId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, true)); + if (parentTaskId != null) { + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); + } + if (taskTypes != null && taskTypes.size() > 0) { + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, TaskType.taskTypeToString(taskTypes))); + } + if (entity != null && !ParseUtils.isNullOrEmpty(entity.getAttributes())) { + String path = "entity"; + String entityKeyFieldName = path + ".name"; + String entityValueFieldName = path + ".value"; + + for (Map.Entry<String, String> attribute : entity.getAttributes().entrySet()) { + BoolQueryBuilder entityBoolQuery = new BoolQueryBuilder(); + TermQueryBuilder entityKeyFilterQuery = QueryBuilders.termQuery(entityKeyFieldName, attribute.getKey()); + TermQueryBuilder entityValueFilterQuery = QueryBuilders.termQuery(entityValueFieldName, attribute.getValue()); + + entityBoolQuery.filter(entityKeyFilterQuery).filter(entityValueFilterQuery); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(path, entityBoolQuery, ScoreMode.None); + query.filter(nestedQueryBuilder); + } + } + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).size(size); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(sourceBuilder); + searchRequest.indices(stateIndex); + + client.search(searchRequest, ActionListener.wrap(r -> { + // https://github.com/opendistro-for-elasticsearch/anomaly-detection/pull/359#discussion_r558653132 + // getTotalHits will be null when track_total_hits is false in the query request. + // Add more checking here to cover some unknown cases. + List<TaskClass> tsTasks = new ArrayList<>(); + if (r == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { + // don't throw exception here as consumer functions need to handle a missing task + // in different ways.
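+ // For example, the stop-job path treats a missing task as "job is already stopped", while the
+ // start-job path treats it as "safe to create a new task".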
+ function.accept(tsTasks); + return; + } + BiCheckedFunction parserMethod = getTaskParser(); + Iterator iterator = r.getHits().iterator(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + TaskClass tsTask = parserMethod.apply(parser, searchHit.getId()); + tsTasks.add(tsTask); + } catch (Exception e) { + String message = "Failed to parse task for config " + configId + ", task id " + searchHit.getId(); + logger.error(message, e); + listener.onFailure(new OpenSearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR)); + } + } + if (resetTaskState) { + resetLatestConfigTaskState(tsTasks, function, transportService, listener); + } else { + function.accept(tsTasks); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.accept(new ArrayList<>()); + } else { + logger.error("Failed to search task for config " + configId, e); + listener.onFailure(e); + } + })); + } + + /** + * Reset latest config task state. Will reset both historical and realtime tasks. + * [Important!] Make sure listener returns in function + * + * @param adTasks ad tasks + * @param function consumer function + * @param transportService transport service + * @param listener action listener + * @param response type of action listener + */ + protected void resetLatestConfigTaskState( + List adTasks, + Consumer> function, + TransportService transportService, + ActionListener listener + ) { + List runningHistoricalTasks = new ArrayList<>(); + List runningRealtimeTasks = new ArrayList<>(); + for (TimeSeriesTask adTask : adTasks) { + if (!adTask.isEntityTask() && !adTask.isDone()) { + if (!adTask.isHistoricalTask()) { + // try to reset task state if realtime task is not ended + runningRealtimeTasks.add(adTask); + } else { + // try to reset task state if historical task not updated for 2 piece intervals + runningHistoricalTasks.add(adTask); + } + } + } + + resetHistoricalConfigTaskState( + runningHistoricalTasks, + () -> resetRealtimeConfigTaskState(runningRealtimeTasks, () -> function.accept(adTasks), transportService, listener), + transportService, + listener + ); + } + + private void resetRealtimeConfigTaskState( + List runningRealtimeTasks, + ExecutorFunction function, + TransportService transportService, + ActionListener listener + ) { + if (ParseUtils.isNullOrEmpty(runningRealtimeTasks)) { + function.execute(); + return; + } + TimeSeriesTask tsTask = runningRealtimeTasks.get(0); + String configId = tsTask.getConfigId(); + GetRequest getJobRequest = new GetRequest(CommonName.JOB_INDEX).id(configId); + client.get(getJobRequest, ActionListener.wrap(r -> { + if (r.isExists()) { + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + Job job = Job.parse(parser); + if (!job.isEnabled()) { + logger.debug("job is disabled, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, transportService, listener); + } else { + function.execute(); + } + } catch (IOException e) { + logger.error(" Failed to parse job " + configId, e); + listener.onFailure(e); + } + } else { + logger.debug("job is not found, reset realtime task as stopped for config {}", configId); + resetTaskStateAsStopped(tsTask, function, 
transportService, listener); + } + }, e -> { + logger.error("Failed to get realtime job for config " + configId, e); + listener.onFailure(e); + })); + } + + /** + * Handle exceptions for task. Update task state and record error message. + * + * @param task time series task + * @param e exception + */ + public void handleTaskException(TaskClass task, Exception e) { + // TODO: handle timeout exception + String state = TaskState.FAILED.name(); + Map<String, Object> updatedFields = new HashMap<>(); + if (e instanceof DuplicateTaskException) { + // If a user sends multiple start requests for the same config, we hit a race condition. + // The cache manager will put the first request in cache and throw DuplicateTaskException + // for the second request. We will delete the second task. + logger .warn( "There is already one running task for config, configId:" + task.getConfigId() + ". Will delete task " + task.getTaskId() ); + deleteTask(task.getTaskId()); + return; + } + if (e instanceof TaskCancelledException) { + logger.info("task cancelled, taskId: {}, configId: {}", task.getTaskId(), task.getConfigId()); + state = TaskState.STOPPED.name(); + String stoppedBy = ((TaskCancelledException) e).getCancelledBy(); + if (stoppedBy != null) { + updatedFields.put(TimeSeriesTask.STOPPED_BY_FIELD, stoppedBy); + } + } else { + logger.error("Failed to execute batch task, task id: " + task.getTaskId() + ", config id: " + task.getConfigId(), e); + } + updatedFields.put(TimeSeriesTask.ERROR_FIELD, ExceptionUtil.getErrorMessage(e)); + updatedFields.put(TimeSeriesTask.STATE_FIELD, state); + updatedFields.put(TimeSeriesTask.EXECUTION_END_TIME_FIELD, Instant.now().toEpochMilli()); + updateTask(task.getTaskId(), updatedFields); + } + + /** + * Update task with specific fields. + * + * @param taskId task id + * @param updatedFields updated fields, key: field name, value: new value + */ + public void updateTask(String taskId, Map<String, Object> updatedFields) { + updateTask(taskId, updatedFields, ActionListener.wrap(response -> { + if (response.status() == RestStatus.OK) { + logger.debug("Updated task successfully: {}, task id: {}", response.status(), taskId); + } else { + logger.error("Failed to update task {}, status: {}", taskId, response.status()); + } + }, e -> { logger.error("Failed to update task: " + taskId, e); })); + } + + /** + * Update task with specific fields, notifying the given listener of the result. + * + * @param taskId task id + * @param updatedFields updated fields, key: field name, value: new value + * @param listener action listener + */ + public void updateTask(String taskId, Map<String, Object> updatedFields, ActionListener<UpdateResponse> listener) { + UpdateRequest updateRequest = new UpdateRequest(stateIndex, taskId); + Map<String, Object> updatedContent = new HashMap<>(); + updatedContent.putAll(updatedFields); + updatedContent.put(TimeSeriesTask.LAST_UPDATE_TIME_FIELD, Instant.now().toEpochMilli()); + updateRequest.doc(updatedContent); + updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.update(updateRequest, listener); + } + + /** + * Delete task with task id. + * + * @param taskId task id + */ + public void deleteTask(String taskId) { + deleteTask( taskId, ActionListener .wrap( r -> { logger.info("Deleted task {} with status: {}", taskId, r.status()); }, e -> { logger.error("Failed to delete task " + taskId, e); } ) ); + } + + /** + * Delete task with task id.
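+ * Unlike the overload above, this variant lets the caller observe the delete response or failure through its own listener.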
+ * + * @param taskId task id + * @param listener action listener + */ + public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) { + DeleteRequest deleteRequest = new DeleteRequest(stateIndex, taskId); + client.delete(deleteRequest, listener); + } + + /** + * Create config task directly without checking whether the index exists or not. + * [Important!] Make sure listener returns in function + * + * @param tsTask Time series task + * @param function consumer function + * @param listener action listener + * @param <T> action listener response type + */ + public <T> void createTaskDirectly(TaskClass tsTask, Consumer<IndexResponse> function, ActionListener<T> listener) { + IndexRequest request = new IndexRequest(stateIndex); + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + request .source(tsTask.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(request, ActionListener.wrap(r -> function.accept(r), e -> { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + })); + } catch (Exception e) { + logger.error("Failed to create task for config " + tsTask.getConfigId(), e); + listener.onFailure(e); + } + } + + protected void cleanOldConfigTaskDocs(IndexResponse response, TaskClass tsTask, ActionListener<JobResponse> delegatedListener) { + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(configIdFieldName, tsTask.getConfigId())); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, false)); + + if (tsTask.isHistoricalTask()) { + // If historical task, only delete detector level task. It may take longer time to delete entity tasks. + // We will delete child task (entity task) of detector level task in hourly cron job. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(HISTORICAL_DETECTOR_TASK_TYPES))); + } else { + // We don't have entity level task for realtime detection, so will delete all tasks. + query.filter(new TermsQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, taskTypeToString(REALTIME_TASK_TYPES))); + } + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder .query(query) .sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC) + // Search query "from" starts from 0.
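+ // Because results are sorted by execution start time descending, skipping the first maxOldAdTaskDocsPerConfig
+ // hits keeps that many of the newest old tasks and selects everything older for deletion.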
+ .from(maxOldAdTaskDocsPerConfig) + .size(MAX_OLD_AD_TASK_DOCS); + searchRequest.source(sourceBuilder).indices(DETECTION_STATE_INDEX); + String detectorId = tsTask.getConfigId(); + + deleteTaskDocs(detectorId, searchRequest, () -> { + if (tsTask.isHistoricalTask()) { + // run batch result action for historical detection + runBatchResultAction(response, tsTask, delegatedListener); + } else { + // return response directly for realtime detection + JobResponse anomalyDetectorJobResponse = new JobResponse(response.getId()); + delegatedListener.onResponse(anomalyDetectorJobResponse); + } + }, delegatedListener); + } + + protected void deleteTaskDocs( + String detectorId, + SearchRequest searchRequest, + ExecutorFunction function, + ActionListener listener + ) { + ActionListener searchListener = ActionListener.wrap(r -> { + Iterator iterator = r.getHits().iterator(); + if (iterator.hasNext()) { + BulkRequest bulkRequest = new BulkRequest(); + while (iterator.hasNext()) { + SearchHit searchHit = iterator.next(); + try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, searchHit.getSourceRef())) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + ADTask adTask = ADTask.parse(parser, searchHit.getId()); + logger.debug("Delete old task: {} of detector: {}", adTask.getTaskId(), adTask.getConfigId()); + bulkRequest.add(new DeleteRequest(DETECTION_STATE_INDEX).id(adTask.getTaskId())); + } catch (Exception e) { + listener.onFailure(e); + } + } + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + logger.info("Old AD tasks deleted for detector {}", detectorId); + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Add detector task into cache. Task id: {}", bulkItemResponse.getId()); + // add deleted task in cache and delete its child tasks and AD results + taskCacheManager.addDeletedTask(bulkItemResponse.getId()); + } + } + } + // delete child tasks and results of this task + cleanChildTasksAndADResultsOfDeletedTask(); + + function.execute(); + }, e -> { + logger.warn("Failed to clean tasks for config " + detectorId, e); + listener.onFailure(e); + })); + } else { + function.execute(); + } + }, e -> { + if (e instanceof IndexNotFoundException) { + function.execute(); + } else { + listener.onFailure(e); + } + }); + + client.search(searchRequest, searchListener); + } + + /** + * Poll deleted detector task from cache and delete its child tasks and AD results. 
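+ * After handling one task, the method calls itself again, so the deleted-task cache is drained one task at a time
+ * on the batch task thread pool.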
+ */ + public void cleanChildTasksAndADResultsOfDeletedTask() { + if (!taskCacheManager.hasDeletedTask()) { + return; + } + threadPool.schedule(() -> { + String taskId = taskCacheManager.pollDeletedTask(); + if (taskId == null) { + return; + } + DeleteByQueryRequest deleteADResultsRequest = new DeleteByQueryRequest(allResultIndexPattern); + deleteADResultsRequest.setQuery(new TermsQueryBuilder(CommonName.TASK_ID_FIELD, taskId)); + client.execute(DeleteByQueryAction.INSTANCE, deleteADResultsRequest, ActionListener.wrap(res -> { + logger.debug("Successfully deleted results of task " + taskId); + DeleteByQueryRequest deleteChildTasksRequest = new DeleteByQueryRequest(DETECTION_STATE_INDEX); + deleteChildTasksRequest.setQuery(new TermsQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, taskId)); + + client.execute(DeleteByQueryAction.INSTANCE, deleteChildTasksRequest, ActionListener.wrap(r -> { + logger.debug("Successfully deleted child tasks of task " + taskId); + cleanChildTasksAndADResultsOfDeletedTask(); + }, e -> { logger.error("Failed to delete child tasks of task " + taskId, e); })); + }, ex -> { logger.error("Failed to delete results for task " + taskId, ex); })); + }, TimeValue.timeValueSeconds(DEFAULT_MAINTAIN_INTERVAL_IN_SECONDS), batchTaskThreadPoolName); + } + + protected void resetEntityTasksAsStopped(String configTaskId) { + UpdateByQueryRequest updateByQueryRequest = new UpdateByQueryRequest(); + updateByQueryRequest.indices(stateIndex); + BoolQueryBuilder query = new BoolQueryBuilder(); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, configTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.TASK_TYPE_FIELD, ADTaskType.AD_HISTORICAL_HC_ENTITY.name())); + query.filter(new TermsQueryBuilder(TimeSeriesTask.STATE_FIELD, NOT_ENDED_STATES)); + updateByQueryRequest.setQuery(query); + updateByQueryRequest.setRefresh(true); + String script = String.format(Locale.ROOT, "ctx._source.%s='%s';", TimeSeriesTask.STATE_FIELD, TaskState.STOPPED.name()); + updateByQueryRequest.setScript(new Script(script)); + + client.execute(UpdateByQueryAction.INSTANCE, updateByQueryRequest, ActionListener.wrap(r -> { + List bulkFailures = r.getBulkFailures(); + if (ParseUtils.isNullOrEmpty(bulkFailures)) { + logger.debug("Updated {} child entity tasks state for config task {}", r.getUpdated(), configTaskId); + } else { + logger.error("Failed to update child entity task's state for config task {} ", configTaskId); + } + }, e -> logger.error("Exception happened when update child entity task's state for config task " + configTaskId, e))); + } + + /** + * Set old task's latest flag as false. 
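+ * Each task is re-indexed with its latest flag set to false and a refreshed last update time, then written in one bulk request.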
+ * @param tasks list of tasks + */ + public void resetLatestFlagAsFalse(List<TaskClass> tasks) { + if (tasks == null || tasks.size() == 0) { + return; + } + BulkRequest bulkRequest = new BulkRequest(); + tasks.forEach(task -> { + try { + task.setLatest(false); + task.setLastUpdateTime(Instant.now()); + IndexRequest indexRequest = new IndexRequest(stateIndex) + .id(task.getTaskId()) + .source(task.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (Exception e) { + logger.error("Failed to parse task to XContent, task id " + task.getTaskId(), e); + } + }); + + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(res -> { + BulkItemResponse[] bulkItemResponses = res.getItems(); + if (bulkItemResponses != null && bulkItemResponses.length > 0) { + for (BulkItemResponse bulkItemResponse : bulkItemResponses) { + if (!bulkItemResponse.isFailed()) { + logger.debug("Reset task latest flag as false successfully. Task id: {}", bulkItemResponse.getId()); + } else { + logger.warn("Failed to reset task latest flag as false. Task id: " + bulkItemResponse.getId()); + } + } + } + }, e -> { logger.warn("Failed to reset tasks latest flag as false", e); })); + } + + public abstract void startHistorical( Config config, DateRange dateRange, User user, TransportService transportService, ActionListener<JobResponse> listener ); + + protected abstract List<TaskTypeEnum> getTaskTypes(DateRange dateRange, boolean resetLatestTaskStateFlag); + + protected abstract TaskType getTaskType(Config config, DateRange dateRange); + + protected abstract void createNewTask( Config config, DateRange dateRange, User user, String coordinatingNode, ActionListener<JobResponse> listener ); + + protected abstract List<TaskTypeEnum> getRealTimeTaskTypes(); + + public abstract <T> void cleanConfigCache( TimeSeriesTask task, TransportService transportService, ExecutorFunction function, ActionListener<T> listener ); + + protected abstract boolean isHistoricalHCTask(TimeSeriesTask task); + + public abstract void stopHistoricalAnalysis( String detectorId, Optional<TaskClass> adTask, User user, ActionListener<JobResponse> listener ); + + protected abstract <T> void resetHistoricalConfigTaskState( List<TimeSeriesTask> runningHistoricalTasks, ExecutorFunction function, TransportService transportService, ActionListener<T> listener ); + + protected abstract void onIndexConfigTaskResponse( IndexResponse response, TaskClass adTask, BiConsumer<IndexResponse, ActionListener<JobResponse>> function, ActionListener<JobResponse> listener ); + + protected abstract void runBatchResultAction(IndexResponse response, TaskClass tsTask, ActionListener<JobResponse> listener); + + protected abstract BiCheckedFunction<XContentParser, String, TaskClass, IOException> getTaskParser(); +} diff --git a/src/main/java/org/opensearch/ad/transport/BackPressureRouting.java b/src/main/java/org/opensearch/timeseries/transport/BackPressureRouting.java similarity index 98% rename from src/main/java/org/opensearch/ad/transport/BackPressureRouting.java rename to src/main/java/org/opensearch/timeseries/transport/BackPressureRouting.java index e5f4ba9b8..bfec0fe95 100644 --- a/src/main/java/org/opensearch/ad/transport/BackPressureRouting.java +++ b/src/main/java/org/opensearch/timeseries/transport/BackPressureRouting.java @@ -9,7 +9,7 @@ * GitHub history for details.
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.time.Clock; import java.util.concurrent.atomic.AtomicInteger; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java similarity index 51% rename from src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java rename to src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java index b7a3bee88..8a638e401 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelTransportAction.java +++ b/src/main/java/org/opensearch/timeseries/transport/BaseDeleteModelTransportAction.java @@ -1,62 +1,56 @@ /* + * Copyright OpenSearch Contributors * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.ActionListener; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.inject.Inject; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import org.opensearch.timeseries.task.TaskCacheManager; import org.opensearch.transport.TransportService; -public class DeleteModelTransportAction extends - TransportNodesAction { - private static final Logger LOG = LogManager.getLogger(DeleteModelTransportAction.class); +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +public class BaseDeleteModelTransportAction, CacheProviderType extends CacheProvider, TaskCacheManagerType extends TaskCacheManager, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart> + extends TransportNodesAction { + + private static final Logger LOG = LogManager.getLogger(BaseDeleteModelTransportAction.class); private NodeStateManager nodeStateManager; - private ModelManager modelManager; - private FeatureManager featureManager; - private CacheProvider cache; - private ADTaskCacheManager adTaskCacheManager; - private EntityColdStarter 
coldStarter; - - @Inject - public DeleteModelTransportAction( + private CacheProviderType cache; + private TaskCacheManagerType adTaskCacheManager; + private ModelColdStartType coldStarter; + + public BaseDeleteModelTransportAction( ThreadPool threadPool, ClusterService clusterService, TransportService transportService, ActionFilters actionFilters, NodeStateManager nodeStateManager, - ModelManager modelManager, - FeatureManager featureManager, - CacheProvider cache, - ADTaskCacheManager adTaskCacheManager, - EntityColdStarter coldStarter + CacheProviderType cache, + TaskCacheManagerType taskCacheManager, + ModelColdStartType coldStarter, + String deleteModelAction ) { super( - DeleteModelAction.NAME, + deleteModelAction, threadPool, clusterService, transportService, @@ -67,10 +61,8 @@ public DeleteModelTransportAction( DeleteModelNodeResponse.class ); this.nodeStateManager = nodeStateManager; - this.modelManager = modelManager; - this.featureManager = featureManager; this.cache = cache; - this.adTaskCacheManager = adTaskCacheManager; + this.adTaskCacheManager = taskCacheManager; this.coldStarter = coldStarter; } @@ -104,34 +96,18 @@ protected DeleteModelNodeResponse newNodeResponse(StreamInput in) throws IOExcep @Override protected DeleteModelNodeResponse nodeOperation(DeleteModelNodeRequest request) { - String adID = request.getAdID(); - LOG.info("Delete model for {}", adID); - // delete in-memory models and model checkpoint - modelManager - .clear( - adID, - ActionListener - .wrap( - r -> LOG.info("Deleted model for [{}] with response [{}] ", adID, r), - e -> LOG.error("Fail to delete model for " + adID, e) - ) - ); + String configID = request.getConfigID(); + LOG.info("Delete model for {}", configID); + nodeStateManager.clear(configID); - // delete buffered shingle data - featureManager.clear(adID); + cache.get().clear(configID); - // delete transport state - nodeStateManager.clear(adID); - - cache.get().clear(adID); - - coldStarter.clear(adID); + coldStarter.clear(configID); // delete realtime task cache - adTaskCacheManager.removeRealtimeTaskCache(adID); + adTaskCacheManager.removeRealtimeTaskCache(configID); - LOG.info("Finished deleting {}", adID); + LOG.info("Finished deleting {}", configID); return new DeleteModelNodeResponse(clusterService.localNode()); } - } diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java new file mode 100644 index 000000000..d73dcefbd --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseGetConfigTransportAction.java @@ -0,0 +1,367 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.forecast.constant.ForecastCommonMessages.FAIL_TO_GET_FORECASTER; +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.ActionListener; +import 
org.opensearch.action.ActionResponse; +import org.opensearch.action.get.MultiGetItemResponse; +import org.opensearch.action.get.MultiGetRequest; +import org.opensearch.action.get.MultiGetResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.TransportService; + +public abstract class BaseGetConfigTransportAction & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager, ConfigType extends Config> + extends HandledTransportAction { + + private static final Logger LOG = LogManager.getLogger(BaseGetConfigTransportAction.class); + + protected final ClusterService clusterService; + protected final Client client; + protected final SecurityClientUtil clientUtil; + // private final Set allProfileTypeStrs; + // private final Set allProfileTypes; + // private final Set defaultDetectorProfileTypes; + // private final Set allEntityProfileTypeStrs; + // private final Set allEntityProfileTypes; + // private final Set defaultEntityProfileTypes; + protected final NamedXContentRegistry xContentRegistry; + protected final DiscoveryNodeFilterer nodeFilter; + protected final TransportService transportService; + protected volatile Boolean filterByEnabled; + protected final TaskManagerType taskManager; + private final Class configTypeClass; + private final String configParseFieldName; + private final List allTaskTypes; + private final String singleStreamRealTimeTaskName; + private final String hcRealTImeTaskName; + private final String singleStreamHistoricalTaskname; + private final String hcHistoricalTaskName; + private final Class getConfigResponseClass; + + public BaseGetConfigTransportAction( + TransportService transportService, + DiscoveryNodeFilterer nodeFilter, + ActionFilters actionFilters, + ClusterService clusterService, + Client client, + SecurityClientUtil clientUtil, + Settings settings, + NamedXContentRegistry xContentRegistry, + TaskManagerType forecastTaskManager, + String getConfigAction, + Class configTypeClass, + String configParseFieldName, + List allTaskTypes, + String hcRealTImeTaskName, + String 
singleStreamRealTimeTaskName, + String hcHistoricalTaskName, + String singleStreamHistoricalTaskname, + Setting filterByBackendRoleEnableSetting, + Class getConfigResponseClass + ) { + super(getConfigAction, transportService, actionFilters, GetConfigRequest::new); + this.clusterService = clusterService; + this.client = client; + this.clientUtil = clientUtil; + // List allProfiles = Arrays.asList(DetectorProfileName.values()); + // this.allProfileTypes = EnumSet.copyOf(allProfiles); + // this.allProfileTypeStrs = getProfileListStrs(allProfiles); + // List defaultProfiles = Arrays.asList(DetectorProfileName.ERROR, DetectorProfileName.STATE); + // this.defaultDetectorProfileTypes = new HashSet(defaultProfiles); + // + // List allEntityProfiles = Arrays.asList(EntityProfileName.values()); + // this.allEntityProfileTypes = EnumSet.copyOf(allEntityProfiles); + // this.allEntityProfileTypeStrs = getProfileListStrs(allEntityProfiles); + // List defaultEntityProfiles = Arrays.asList(EntityProfileName.STATE); + // this.defaultEntityProfileTypes = new HashSet(defaultEntityProfiles); + + this.xContentRegistry = xContentRegistry; + this.nodeFilter = nodeFilter; + filterByEnabled = filterByBackendRoleEnableSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleEnableSetting, it -> filterByEnabled = it); + this.transportService = transportService; + this.taskManager = forecastTaskManager; + this.configTypeClass = configTypeClass; + this.configParseFieldName = configParseFieldName; + this.allTaskTypes = allTaskTypes; + this.hcRealTImeTaskName = hcRealTImeTaskName; + this.singleStreamRealTimeTaskName = singleStreamRealTimeTaskName; + this.hcHistoricalTaskName = hcHistoricalTaskName; + this.singleStreamHistoricalTaskname = singleStreamHistoricalTaskname; + this.getConfigResponseClass = getConfigResponseClass; + } + + @Override + protected void doExecute(Task task, GetConfigRequest request, ActionListener actionListener) { + String configID = request.getConfigID(); + User user = ParseUtils.getUserContext(client); + ActionListener listener = wrapRestActionListener(actionListener, FAIL_TO_GET_FORECASTER); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configID, + filterByEnabled, + listener, + (config) -> getExecute(request, listener), + client, + clusterService, + xContentRegistry, + getConfigResponseClass, + configTypeClass + ); + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + protected void getConfigAndJob( + String configID, + boolean returnJob, + boolean returnTask, + Optional realtimeConfigTask, + Optional historicalConfigTask, + ActionListener listener + ) { + MultiGetRequest.Item adItem = new MultiGetRequest.Item(CommonName.CONFIG_INDEX, configID); + MultiGetRequest multiGetRequest = new MultiGetRequest().add(adItem); + if (returnJob) { + MultiGetRequest.Item adJobItem = new MultiGetRequest.Item(CommonName.JOB_INDEX, configID); + multiGetRequest.add(adJobItem); + } + client + .multiGet( + multiGetRequest, + onMultiGetResponse(listener, returnJob, returnTask, realtimeConfigTask, historicalConfigTask, configID) + ); + } + + protected void getExecute(GetConfigRequest request, ActionListener listener) { + String configID = request.getConfigID(); + String typesStr = request.getTypeStr(); + String rawPath = request.getRawPath(); + Entity entity = request.getEntity(); + boolean all = request.isAll(); + boolean returnJob = 
request.isReturnJob(); + boolean returnTask = request.isReturnTask(); + + try { + if (!Strings.isEmpty(typesStr) || rawPath.endsWith(PROFILE) || rawPath.endsWith(PROFILE + "/")) { + getExecuteProfile(request, entity, typesStr, all, configID, listener); + } else { + if (returnTask) { + taskManager.getAndExecuteOnLatestTasks(configID, null, null, allTaskTypes, (taskList) -> { + Optional<TaskClass> realtimeTask = Optional.empty(); + Optional<TaskClass> historicalTask = Optional.empty(); + if (taskList != null && taskList.size() > 0) { + Map<String, TaskClass> tasks = new HashMap<>(); + List<TaskClass> duplicateTasks = new ArrayList<>(); + for (TaskClass task : taskList) { + if (tasks.containsKey(task.getTaskType())) { + LOG .info( "Found duplicate latest task of config {}, task id: {}, task type: {}", configID, task.getTaskId(), task.getTaskType() ); + duplicateTasks.add(task); + continue; + } + tasks.put(task.getTaskType(), task); + } + if (duplicateTasks.size() > 0) { + taskManager.resetLatestFlagAsFalse(duplicateTasks); + } + + if (tasks.containsKey(hcRealTImeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(hcRealTImeTaskName)); + } else if (tasks.containsKey(singleStreamRealTimeTaskName)) { + realtimeTask = Optional.ofNullable(tasks.get(singleStreamRealTimeTaskName)); + } + if (tasks.containsKey(hcHistoricalTaskName)) { + historicalTask = Optional.ofNullable(tasks.get(hcHistoricalTaskName)); + } else if (tasks.containsKey(singleStreamHistoricalTaskname)) { + historicalTask = Optional.ofNullable(tasks.get(singleStreamHistoricalTaskname)); + } else { + // AD needs to provide custom behavior for bwc, while forecasting can inherit + // the empty implementation + fillInHistoricalTaskforBwc(tasks, historicalTask); + } + } + getConfigAndJob(configID, returnJob, returnTask, realtimeTask, historicalTask, listener); + }, transportService, true, 2, listener); + } else { + getConfigAndJob(configID, returnJob, returnTask, Optional.empty(), Optional.empty(), listener); + } + } + } catch (Exception e) { + LOG.error(e); + listener.onFailure(e); + } + } + + private ActionListener<MultiGetResponse> onMultiGetResponse( + ActionListener<ResponseType> listener, + boolean returnJob, + boolean returnTask, + Optional<TaskClass> realtimeTask, + Optional<TaskClass> historicalTask, + String configId + ) { + return new ActionListener<MultiGetResponse>() { + @Override + public void onResponse(MultiGetResponse multiGetResponse) { + MultiGetItemResponse[] responses = multiGetResponse.getResponses(); + ConfigType config = null; + Job job = null; + String id = null; + long version = 0; + long seqNo = 0; + long primaryTerm = 0; + + for (MultiGetItemResponse response : responses) { + if (CommonName.CONFIG_INDEX.equals(response.getIndex())) { + if (response.getResponse() == null || !response.getResponse().isExists()) { + listener .onFailure( new OpenSearchStatusException(CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId, RestStatus.NOT_FOUND) ); + return; + } + id = response.getId(); + version = response.getResponse().getVersion(); + primaryTerm = response.getResponse().getPrimaryTerm(); + seqNo = response.getResponse().getSeqNo(); + if (!response.getResponse().isSourceEmpty()) { + try ( XContentParser parser = RestHandlerUtils .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + config = parser.namedObject(configTypeClass, configParseFieldName, null); + } catch (Exception e) { + String message = "Failed to parse config " + configId;
listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } else if (CommonName.JOB_INDEX.equals(response.getIndex())) { + if (response.getResponse() != null + && response.getResponse().isExists() + && !response.getResponse().isSourceEmpty()) { + try ( + XContentParser parser = RestHandlerUtils .createXContentParserFromRegistry(xContentRegistry, response.getResponse().getSourceAsBytesRef()) ) { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + job = Job.parse(parser); + } catch (Exception e) { + String message = "Failed to parse job " + configId; + listener.onFailure(buildInternalServerErrorResponse(e, message)); + return; + } + } + } + } + listener .onResponse( createResponse( version, id, primaryTerm, seqNo, config, job, returnJob, realtimeTask, historicalTask, returnTask, RestStatus.OK ) ); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }; + } + + protected void fillInHistoricalTaskforBwc(Map<String, TaskClass> tasks, Optional<TaskClass> historicalAdTask) {} + + protected abstract void getExecuteProfile( GetConfigRequest request, Entity entity, String typesStr, boolean all, String configId, ActionListener<ResponseType> listener ); + + protected abstract ResponseType createResponse( long version, String id, long primaryTerm, long seqNo, ConfigType config, Job job, boolean returnJob, Optional<TaskClass> realtimeTask, Optional<TaskClass> historicalTask, boolean returnTask, RestStatus restStatus ); + + protected OpenSearchStatusException buildInternalServerErrorResponse(Exception e, String errorMsg) { + LOG.error(errorMsg, e); + return new OpenSearchStatusException(errorMsg, RestStatus.INTERNAL_SERVER_ERROR); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java new file mode 100644 index 000000000..efa9712c9 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/BaseJobTransportAction.java @@ -0,0 +1,147 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute; +import static org.opensearch.timeseries.util.RestHandlerUtils.wrapRestActionListener; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.tasks.Task; +import org.opensearch.timeseries.ExecuteResultResponseRecorder; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.TaskType; +import
org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.ParseUtils; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +public abstract class BaseJobTransportAction< + IndexType extends Enum & TimeSeriesIndex, + IndexManagementType extends IndexManagement, + TaskCacheManagerType extends TaskCacheManager, + TaskTypeEnum extends TaskType, + TaskClass extends TimeSeriesTask, + TaskManagerType extends TaskManager, + IndexableResultType extends IndexableResult, + ExecuteResultResponseRecorderType extends ExecuteResultResponseRecorder< + IndexType, IndexManagementType, TaskCacheManagerType, TaskTypeEnum, TaskClass, TaskManagerType, IndexableResultType + >, + IndexJobActionHandlerType extends IndexJobActionHandler< + IndexType, IndexManagementType, TaskCacheManagerType, TaskTypeEnum, TaskClass, TaskManagerType, IndexableResultType, ExecuteResultResponseRecorderType + > + > + extends HandledTransportAction { + private final Logger logger = LogManager.getLogger(BaseJobTransportAction.class); + + private final Client client; + private final ClusterService clusterService; + private final Settings settings; + private final NamedXContentRegistry xContentRegistry; + private volatile Boolean filterByEnabled; + private final TransportService transportService; + private final Setting requestTimeOutSetting; + private final String failtoStartMsg; + private final String failtoStopMsg; + private final Class configClass; + private final IndexJobActionHandlerType indexJobActionHandlerType; + + public BaseJobTransportAction( + TransportService transportService, + ActionFilters actionFilters, + Client client, + ClusterService clusterService, + Settings settings, + NamedXContentRegistry xContentRegistry, + Setting filterByBackendRoleSettng, + String jobActionName, + Setting requestTimeOutSetting, + String failtoStartMsg, + String failtoStopMsg, + Class configClass, + IndexJobActionHandlerType indexJobActionHandlerType + ) { + super(jobActionName, transportService, actionFilters, JobRequest::new); + this.transportService = transportService; + this.client = client; + this.clusterService = clusterService; + this.settings = settings; + this.xContentRegistry = xContentRegistry; + filterByEnabled = filterByBackendRoleSettng.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(filterByBackendRoleSettng, it -> filterByEnabled = it); + this.requestTimeOutSetting = requestTimeOutSetting; + this.failtoStartMsg = failtoStartMsg; + this.failtoStopMsg = failtoStopMsg; + this.configClass = configClass; + this.indexJobActionHandlerType = indexJobActionHandlerType; + } + + @Override + protected void doExecute(Task task, JobRequest request, ActionListener actionListener) { + String configId = request.getConfigID(); + DateRange dateRange = request.getDateRange(); + boolean historical = request.isHistorical(); + String rawPath = request.getRawPath(); + TimeValue requestTimeout = requestTimeOutSetting.get(settings); + String errorMessage = rawPath.endsWith(RestHandlerUtils.START_JOB) ? failtoStartMsg : failtoStopMsg; + ActionListener listener = wrapRestActionListener(actionListener, errorMessage); + + // By the time request reaches here, the user permissions are validated by Security plugin. 
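+ // resolveUserAndExecute below additionally enforces backend-role filtering when filterByEnabled is true.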
+ User user = ParseUtils.getUserContext(client); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + resolveUserAndExecute( + user, + configId, + filterByEnabled, + listener, + (config) -> executeConfig(listener, configId, dateRange, historical, rawPath, requestTimeout, user, context), + client, + clusterService, + xContentRegistry, + JobResponse.class, + configClass + ); + } catch (Exception e) { + logger.error(e); + listener.onFailure(e); + } + } + + private void executeConfig( + ActionListener listener, + String configId, + DateRange dateRange, + boolean historical, + String rawPath, + TimeValue requestTimeout, + User user, + ThreadContext.StoredContext context + ) { + if (rawPath.endsWith(RestHandlerUtils.START_JOB)) { + indexJobActionHandlerType.startConfig(configId, dateRange, user, transportService, context, listener); + } else if (rawPath.endsWith(RestHandlerUtils.STOP_JOB)) { + indexJobActionHandlerType.stopConfig(configId, historical, user, transportService, listener); + } + } +} diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java similarity index 73% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java index d10eef4c3..d0db308c0 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -22,26 +22,26 @@ */ public class DeleteModelNodeRequest extends TransportRequest { - private String adID; + private String configID; DeleteModelNodeRequest() {} DeleteModelNodeRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - DeleteModelNodeRequest(DeleteModelRequest request) { - this.adID = request.getAdID(); + public DeleteModelNodeRequest(DeleteModelRequest request) { + this.configID = request.getAdID(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java similarity index 96% rename from src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java index c71e7368c..a57cb0d30 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelNodeResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelNodeResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. 
*/ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java similarity index 77% rename from src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java index 9ec58acda..d6b119e6a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,24 +17,24 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.nodes.BaseNodesRequest; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; /** * Request should be sent from the handler logic of transport delete detector API * */ public class DeleteModelRequest extends BaseNodesRequest implements ToXContentObject { - private String adID; + private String configID; public String getAdID() { - return adID; + return configID; } public DeleteModelRequest() { @@ -43,25 +43,25 @@ public DeleteModelRequest() { public DeleteModelRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } public DeleteModelRequest(String adID, DiscoveryNode... 
nodes) { super(nodes); - this.adID = adID; + this.configID = adID; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -69,7 +69,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } diff --git a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java similarity index 97% rename from src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java rename to src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java index f2cbe2468..a2154481a 100644 --- a/src/main/java/org/opensearch/ad/transport/DeleteModelResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/DeleteModelResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.List; diff --git a/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java new file mode 100644 index 000000000..b866cac0f --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultProcessor.java @@ -0,0 +1,288 @@ +package org.opensearch.timeseries.transport; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.timeseries.caching.CacheProvider; +import org.opensearch.timeseries.caching.TimeSeriesCache; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.CheckpointDao; +import org.opensearch.timeseries.ml.IntermediateResult; +import org.opensearch.timeseries.ml.ModelColdStart; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.ModelState; +import org.opensearch.timeseries.ml.Sample; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.CheckpointReadWorker; +import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker; +import 
org.opensearch.timeseries.ratelimit.ColdEntityWorker; +import org.opensearch.timeseries.ratelimit.ColdStartWorker; +import org.opensearch.timeseries.ratelimit.FeatureRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.ratelimit.ResultWriteWorker; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.transport.handler.IndexMemoryPressureAwareResultHandler; +import org.opensearch.timeseries.util.ParseUtils; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; + +/** + * Shared code to implement an entity result transportation + * (e.g., EntityForecastResultTransportAction) + * + */ +public class EntityResultProcessor, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, IndexableResultType extends IndexableResult, ResultWriteRequestType extends ResultWriteRequest, ResultWriteBatchRequestType extends ResultBulkRequest, CheckpointDaoType extends CheckpointDao, CheckpointWriteWorkerType extends CheckpointWriteWorker, ModelColdStartType extends ModelColdStart, ModelManagerType extends ModelManager, CacheType extends TimeSeriesCache, ColdStartWorkerType extends ColdStartWorker, ResultHandlerType extends IndexMemoryPressureAwareResultHandler, ResultWriteWorkerType extends ResultWriteWorker, HCCheckpointReadWorkerType extends CheckpointReadWorker, ColdEntityWorkerType extends ColdEntityWorker> { + + private static final Logger LOG = LogManager.getLogger(EntityResultProcessor.class); + + private CacheProvider cache; + private ModelManagerType modelManager; + private IndexType resultIndex; + private IndexManagementType indexUtil; + private ResultWriteWorkerType resultWriteQueue; + private Class resultWriteClass; + private Stats stats; + private ColdStartWorkerType entityColdStartWorker; + private HCCheckpointReadWorkerType checkpointReadQueue; + private ColdEntityWorkerType coldEntityQueue; + + public EntityResultProcessor( + CacheProvider cache, + ModelManagerType manager, + IndexType resultIndex, + IndexManagementType indexUtil, + ResultWriteWorkerType resultWriteQueue, + Class resultWriteClass, + Stats stats, + ColdStartWorkerType entityColdStartWorker, + HCCheckpointReadWorkerType checkpointReadQueue, + ColdEntityWorkerType coldEntityQueue + ) { + this.cache = cache; + this.modelManager = manager; + this.resultIndex = resultIndex; + this.indexUtil = indexUtil; + this.resultWriteQueue = resultWriteQueue; + this.resultWriteClass = resultWriteClass; + this.stats = stats; + this.entityColdStartWorker = entityColdStartWorker; + this.checkpointReadQueue = checkpointReadQueue; + this.coldEntityQueue = coldEntityQueue; + } + + public ActionListener> onGetConfig( + ActionListener listener, + String forecasterId, + EntityResultRequest request, + Optional prevException + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(forecasterId, "Config " + forecasterId + " is not available.", false)); + return; + } + + Config config = configOptional.get(); + + if (request.getEntities() == null) { + listener.onFailure(new EndRunException(forecasterId, "Fail to get any entities from request.", false)); + return; + } + + Instant executionStartTime = Instant.now(); + Map cacheMissEntities = new HashMap<>(); + for (Entry entityEntry : request.getEntities().entrySet()) { + Entity entity = 
entityEntry.getKey(); + + if (isEntityFromOldNodeMsg(entity) && config.getCategoryFields() != null && config.getCategoryFields().size() == 1) { + Map attrValues = entity.getAttributes(); + // handle a request from a version before OpenSearch 1.1. + entity = Entity.createSingleAttributeEntity(config.getCategoryFields().get(0), attrValues.get(CommonName.EMPTY_FIELD)); + } + + Optional modelIdOptional = entity.getModelId(forecasterId); + if (false == modelIdOptional.isPresent()) { + continue; + } + + String modelId = modelIdOptional.get(); + double[] datapoint = entityEntry.getValue(); + ModelState entityModel = cache.get().get(modelId, config); + if (entityModel == null) { + // cache miss + cacheMissEntities.put(entity, datapoint); + continue; + } + try { + IntermediateResultType result = modelManager + .getResult( + new Sample(datapoint, Instant.ofEpochMilli(request.getStart()), Instant.ofEpochMilli(request.getEnd())), + entityModel, + modelId, + Optional.ofNullable(entity), + config + ); + // result.getRcfScore() = 0 means the model is not initialized + // result.getGrade() = 0 means it is not an anomaly + // Indexing every result regardless would trigger many OpenSearchRejectedExecutionExceptions, so we only write when the score is positive + if (result.getRcfScore() > 0) { + List indexableResults = result + .toIndexableResults( + config, + Instant.ofEpochMilli(request.getStart()), + Instant.ofEpochMilli(request.getEnd()), + executionStartTime, + Instant.now(), + ParseUtils.getFeatureData(datapoint, config), + Optional.ofNullable(entity), + indexUtil.getSchemaVersion(resultIndex), + modelId, + null, + null + ); + + for (IndexableResultType resultToSave : indexableResults) { + resultWriteQueue + .put( + ResultWriteRequest + .create( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + resultToSave, + config.getCustomResultIndex(), + resultWriteClass + ) + ); + } + } + } catch (IllegalArgumentException e) { + // Failure to score is likely due to model corruption. Trigger a cold start to recover.
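+ // Recovery below: record the corruption in stats, evict the bad model from the cache, and + // enqueue a cold-start request so the entity's model is rebuilt on a later pass.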
+ LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e); + stats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).increment(); + cache.get().removeModel(forecasterId, modelId); + entityColdStartWorker + .put( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + RequestPriority.MEDIUM, + datapoint, + request.getStart(), + entity + ) + ); + } + } + + // split hot and cold entities + Pair, List> hotColdEntities = cache + .get() + .selectUpdateCandidate(cacheMissEntities.keySet(), forecasterId, config); + + List hotEntityRequests = new ArrayList<>(); + List coldEntityRequests = new ArrayList<>(); + + for (Entity hotEntity : hotColdEntities.getLeft()) { + double[] hotEntityValue = cacheMissEntities.get(hotEntity); + if (hotEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", hotEntity)); + continue; + } + hotEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // hot entities have MEDIUM priority + RequestPriority.MEDIUM, + hotEntityValue, + request.getStart(), + hotEntity + ) + ); + } + + for (Entity coldEntity : hotColdEntities.getRight()) { + double[] coldEntityValue = cacheMissEntities.get(coldEntity); + if (coldEntityValue == null) { + LOG.error(new ParameterizedMessage("feature value should not be null: [{}]", coldEntity)); + continue; + } + coldEntityRequests + .add( + new FeatureRequest( + System.currentTimeMillis() + config.getIntervalInMilliseconds(), + forecasterId, + // cold entities have LOW priority + RequestPriority.LOW, + coldEntityValue, + request.getStart(), + coldEntity + ) + ); + } + + checkpointReadQueue.putAll(hotEntityRequests); + coldEntityQueue.putAll(coldEntityRequests); + + // respond back + if (prevException.isPresent()) { + listener.onFailure(prevException.get()); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + }, exception -> { + LOG + .error( + new ParameterizedMessage( + "fail to get entity's anomaly grade for detector [{}]: start: [{}], end: [{}]", + forecasterId, + request.getStart(), + request.getEnd() + ), + exception + ); + listener.onFailure(exception); + }); + } + + /** + * Whether the received entity comes from a node that doesn't support multi-category fields. + * This can happen during a rolling upgrade or blue/green deployment. + * + * Specifically, when receiving an EntityResultRequest from an incompatible node, + * EntityResultRequest(StreamInput in) gets a String that represents an entity. + * But the Entity class requires both a category field name and a value. Since we + * don't have access to the detector config in EntityResultRequest(StreamInput in), + * we put CommonName.EMPTY_FIELD as the placeholder. In this method, + * we use the same CommonName.EMPTY_FIELD to check if the deserialized entity + * comes from an incompatible node. If it does, we will add the field name back + * as EntityResultTransportAction has access to the detector config object. + * + * @param categoricalValues deserialized Entity from the inbound message. + * @return Whether the received entity comes from a node that doesn't support multi-category fields.
+ */ + private boolean isEntityFromOldNodeMsg(Entity categoricalValues) { + Map attrValues = categoricalValues.getAttributes(); + return (attrValues != null && attrValues.containsKey(CommonName.EMPTY_FIELD)); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java similarity index 78% rename from src/main/java/org/opensearch/ad/transport/EntityResultRequest.java rename to src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java index 91041f447..bab1f4b21 100644 --- a/src/main/java/org/opensearch/ad/transport/EntityResultRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/EntityResultRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -17,12 +17,8 @@ import java.util.Locale; import java.util.Map; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -33,16 +29,16 @@ import org.opensearch.timeseries.model.Entity; public class EntityResultRequest extends ActionRequest implements ToXContentObject { - private static final Logger LOG = LogManager.getLogger(EntityResultRequest.class); - private String detectorId; + protected String configId; // changed from Map to Map - private Map entities; - private long start; - private long end; + protected Map entities; + // data start/end time epoch + protected long start; + protected long end; public EntityResultRequest(StreamInput in) throws IOException { super(in); - this.detectorId = in.readString(); + this.configId = in.readString(); // guarded with version check. Just in case we receive requests from older node where we use String // to represent an entity @@ -52,16 +48,16 @@ public EntityResultRequest(StreamInput in) throws IOException { this.end = in.readLong(); } - public EntityResultRequest(String detectorId, Map entities, long start, long end) { + public EntityResultRequest(String configId, Map entities, long start, long end) { super(); - this.detectorId = detectorId; + this.configId = configId; this.entities = entities; this.start = start; this.end = end; } - public String getId() { - return this.detectorId; + public String getConfigId() { + return this.configId; } public Map getEntities() { @@ -79,7 +75,7 @@ public long getEnd() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(this.detectorId); + out.writeString(this.configId); // guarded with version check. 
Just in case we send requests to older node where we use String // to represent an entity out.writeMap(entities, (s, e) -> e.writeTo(s), StreamOutput::writeDoubleArray); @@ -91,8 +87,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(detectorId)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } if (start <= 0 || end <= 0 || start > end) { validationException = addValidationError( @@ -106,7 +102,7 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, detectorId); + builder.field(CommonName.CONFIG_ID_KEY, configId); builder.field(CommonName.START_JSON_KEY, start); builder.field(CommonName.END_JSON_KEY, end); builder.startArray(CommonName.ENTITIES_JSON_KEY); diff --git a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java rename to src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java index aef29626d..1aed87c66 100644 --- a/src/main/java/org/opensearch/ad/transport/GetAnomalyDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/GetConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; @@ -19,9 +19,9 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.timeseries.model.Entity; -public class GetAnomalyDetectorRequest extends ActionRequest { +public class GetConfigRequest extends ActionRequest { - private String detectorID; + private String configID; private long version; private boolean returnJob; private boolean returnTask; @@ -30,9 +30,9 @@ public class GetAnomalyDetectorRequest extends ActionRequest { private boolean all; private Entity entity; - public GetAnomalyDetectorRequest(StreamInput in) throws IOException { + public GetConfigRequest(StreamInput in) throws IOException { super(in); - detectorID = in.readString(); + configID = in.readString(); version = in.readLong(); returnJob = in.readBoolean(); returnTask = in.readBoolean(); @@ -44,7 +44,7 @@ public GetAnomalyDetectorRequest(StreamInput in) throws IOException { } } - public GetAnomalyDetectorRequest( + public GetConfigRequest( String detectorID, long version, boolean returnJob, @@ -55,7 +55,7 @@ public GetAnomalyDetectorRequest( Entity entity ) { super(); - this.detectorID = detectorID; + this.configID = detectorID; this.version = version; this.returnJob = returnJob; this.returnTask = returnTask; @@ -65,8 +65,8 @@ public GetAnomalyDetectorRequest( this.entity = entity; } - public String getDetectorID() { - return detectorID; + public String getConfigID() { + return configID; } public long getVersion() { @@ -100,7 +100,7 @@ public Entity getEntity() { @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(detectorID); + out.writeString(configID); out.writeLong(version); 
out.writeBoolean(returnJob); out.writeBoolean(returnTask); diff --git a/src/main/java/org/opensearch/timeseries/transport/JobRequest.java b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java new file mode 100644 index 000000000..a46767dca --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/JobRequest.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.timeseries.model.DateRange; + +public class JobRequest extends ActionRequest { + + private String configID; + private DateRange dateRange; + private boolean historical; + private String rawPath; + + public JobRequest(StreamInput in) throws IOException { + super(in); + configID = in.readString(); + rawPath = in.readString(); + if (in.readBoolean()) { + dateRange = new DateRange(in); + } + historical = in.readBoolean(); + } + + public JobRequest(String detectorID, String rawPath) { + this(detectorID, null, false, rawPath); + } + + /** + * Constructor function. + * + * The dateRange and historical boolean can be passed in individually. + * The historical flag is for stopping analysis; the dateRange is for + * starting analysis. It's ok if historical is true but dateRange is + * null. + * + * @param configID config identifier + * @param dateRange analysis date range + * @param historical historical analysis or not + * @param rawPath raw request path + */ + public JobRequest(String configID, DateRange dateRange, boolean historical, String rawPath) { + super(); + this.configID = configID; + this.dateRange = dateRange; + this.historical = historical; + this.rawPath = rawPath; + } + + public String getConfigID() { + return configID; + } + + public DateRange getDateRange() { + return dateRange; + } + + public String getRawPath() { + return rawPath; + } + + public boolean isHistorical() { + return historical; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(configID); + out.writeString(rawPath); + if (dateRange != null) { + out.writeBoolean(true); + dateRange.writeTo(out); + } else { + out.writeBoolean(false); + } + out.writeBoolean(historical); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/JobResponse.java b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java new file mode 100644 index 000000000..588651bd3 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/JobResponse.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; + +import org.opensearch.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.util.RestHandlerUtils; + +public class JobResponse extends ActionResponse implements ToXContentObject { + private final String id; + + public JobResponse(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + public JobResponse(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field(RestHandlerUtils._ID, id).endObject(); + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java new file mode 100644 index 000000000..cd8efc9de --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkRequest.java @@ -0,0 +1,88 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.timeseries.transport; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; + +public class ResultBulkRequest> extends + ActionRequest + implements + Writeable { + private final List results; + + public ResultBulkRequest() { + results = new ArrayList<>(); + } + + public ResultBulkRequest(StreamInput in, Writeable.Reader reader) throws IOException { + super(in); + int size = in.readVInt(); + results = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + results.add(reader.read(in)); + } + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (results.isEmpty()) { + validationException = ValidateActions.addValidationError(CommonMessages.NO_REQUESTS_ADDED_ERR, validationException); + } + return validationException; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeVInt(results.size()); + for (ResultWriteRequestType result : results) { + result.writeTo(out); + } + } + + /** + * + * @return all of the results to send + */ + public List getAnomalyResults() { + return results; + } + + /** + * Add result to send + * @param resultWriteRequest The result write request + */ + public void add(ResultWriteRequestType resultWriteRequest) { + results.add(resultWriteRequest); + } + + /** 
+ * + * @return total index requests + */ + public int numberOfActions() { + return results.size(); + } +} diff --git a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java similarity index 86% rename from src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java rename to src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java index 8206d908e..5a55d7354 100644 --- a/src/main/java/org/opensearch/ad/transport/ADResultBulkResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.IOException; import java.util.ArrayList; @@ -21,7 +21,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -public class ADResultBulkResponse extends ActionResponse { +public class ResultBulkResponse extends ActionResponse { public static final String RETRY_REQUESTS_JSON_KEY = "retry_requests"; private List retryRequests; @@ -30,15 +30,15 @@ public class ADResultBulkResponse extends ActionResponse { * * @param retryRequests a list of requests to retry */ - public ADResultBulkResponse(List retryRequests) { + public ResultBulkResponse(List retryRequests) { this.retryRequests = retryRequests; } - public ADResultBulkResponse() { + public ResultBulkResponse() { this.retryRequests = null; } - public ADResultBulkResponse(StreamInput in) throws IOException { + public ResultBulkResponse(StreamInput in) throws IOException { int size = in.readInt(); if (size > 0) { retryRequests = new ArrayList<>(size); diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java new file mode 100644 index 000000000..dcb84640d --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultBulkTransportAction.java @@ -0,0 +1,121 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.opensearch.index.IndexingPressure.MAX_INDEXING_BYTES; + +import java.io.IOException; +import java.util.List; +import java.util.Random; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.bulk.BulkAction; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.IndexingPressure; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.ratelimit.ResultWriteRequest; +import org.opensearch.timeseries.util.BulkUtil; +import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.transport.TransportService; + +@SuppressWarnings("rawtypes") +public abstract class ResultBulkTransportAction, ResultBulkRequestType extends ResultBulkRequest> + extends HandledTransportAction { + private static final Logger LOG = LogManager.getLogger(ResultBulkTransportAction.class); + protected IndexingPressure indexingPressure; + private final long primaryAndCoordinatingLimits; + protected float softLimit; + protected float hardLimit; + protected String indexName; + private Client client; + protected Random random; + + public ResultBulkTransportAction( + String actionName, + TransportService transportService, + ActionFilters actionFilters, + IndexingPressure indexingPressure, + Settings settings, + Client client, + float softLimit, + float hardLimit, + String indexName, + Writeable.Reader requestReader + ) { + super(actionName, transportService, actionFilters, requestReader, ThreadPool.Names.SAME); + this.indexingPressure = indexingPressure; + this.primaryAndCoordinatingLimits = MAX_INDEXING_BYTES.get(settings).getBytes(); + this.client = client; + + this.softLimit = softLimit; + this.hardLimit = hardLimit; + this.indexName = indexName; + + // random seed is 42. Can be any number + this.random = new Random(42); + } + + @Override + protected void doExecute(Task task, ResultBulkRequestType request, ActionListener listener) { + // Concurrent indexing memory limit = 10% of heap + // indexing pressure = indexing bytes / indexing limit + // Write all until index pressure (global indexing memory pressure) is less than 80% of 10% of heap. Otherwise, index + // all non-zero anomaly grade index requests and index zero anomaly grade index requests with probability (1 - index pressure). 
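+ // Illustrative numbers (assumed, not part of this patch): with a 1 GB heap the indexing + // limit is about 100 MB. If 85 MB is in flight, indexingPressurePercent is 0.85; past a 0.8 + // soft limit, non-zero-grade results are still indexed while zero-grade results are kept + // only with probability 1 - 0.85 = 0.15.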
+ long totalBytes = indexingPressure.getCurrentCombinedCoordinatingAndPrimaryBytes() + indexingPressure.getCurrentReplicaBytes(); + float indexingPressurePercent = (float) totalBytes / primaryAndCoordinatingLimits; + @SuppressWarnings("rawtypes") + List results = request.getAnomalyResults(); + + if (results == null || results.size() < 1) { + listener.onResponse(new ResultBulkResponse()); + return; + } + + BulkRequest bulkRequest = prepareBulkRequest(indexingPressurePercent, request); + + if (bulkRequest.numberOfActions() > 0) { + client.execute(BulkAction.INSTANCE, bulkRequest, ActionListener.wrap(bulkResponse -> { + List failedRequests = BulkUtil.getFailedIndexRequest(bulkRequest, bulkResponse); + listener.onResponse(new ResultBulkResponse(failedRequests)); + }, e -> { + LOG.error("Failed to bulk index results", e); + listener.onFailure(e); + })); + } else { + listener.onResponse(new ResultBulkResponse()); + } + } + + protected abstract BulkRequest prepareBulkRequest(float indexingPressurePercent, ResultBulkRequestType request); + + protected void addResult(BulkRequest bulkRequest, ToXContentObject result, String resultIndex) { + String index = resultIndex == null ? indexName : resultIndex; + try (XContentBuilder builder = jsonBuilder()) { + IndexRequest indexRequest = new IndexRequest(index).source(result.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE)); + bulkRequest.add(indexRequest); + } catch (IOException e) { + LOG.error("Failed to prepare bulk index request for index " + index, e); + } + } +} diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java new file mode 100644 index 000000000..4f5003b13 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/ResultProcessor.java @@ -0,0 +1,741 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details.
+ */ + +package org.opensearch.timeseries.transport; + +import java.net.ConnectException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.ExceptionsHelper; +import org.opensearch.OpenSearchTimeoutException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionListenerResponseHandler; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.block.ClusterBlockLevel; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.NetworkExceptionHelper; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.node.NodeClosedException; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; +import org.opensearch.timeseries.cluster.HashRing; +import org.opensearch.timeseries.common.exception.ClientException; +import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.common.exception.InternalFailure; +import org.opensearch.timeseries.common.exception.NotSerializedExceptionName; +import org.opensearch.timeseries.common.exception.ResourceNotFoundException; +import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.feature.CompositeRetriever; +import org.opensearch.timeseries.feature.CompositeRetriever.PageIterator; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.SinglePointFeatures; +import org.opensearch.timeseries.indices.IndexManagement; +import org.opensearch.timeseries.indices.TimeSeriesIndex; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.FeatureData; +import org.opensearch.timeseries.model.IndexableResult; +import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.TaskType; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.Stats; +import org.opensearch.timeseries.task.TaskCacheManager; +import org.opensearch.timeseries.task.TaskManager; +import 
org.opensearch.timeseries.util.ExceptionUtil; +import org.opensearch.timeseries.util.SecurityClientUtil; +import org.opensearch.transport.ActionNotFoundTransportException; +import org.opensearch.transport.ConnectTransportException; +import org.opensearch.transport.NodeNotConnectedException; +import org.opensearch.transport.ReceiveTimeoutTransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.TransportService; + +public abstract class ResultProcessor, TaskCacheManagerType extends TaskCacheManager, TaskTypeEnum extends TaskType, TaskClass extends TimeSeriesTask, IndexType extends Enum & TimeSeriesIndex, IndexManagementType extends IndexManagement, TaskManagerType extends TaskManager> { + + private static final Logger LOG = LogManager.getLogger(ResultProcessor.class); + + static final String WAIT_FOR_THRESHOLD_ERR_MSG = "Exception in waiting for threshold result"; + + static final String NO_ACK_ERR = "no acknowledgements from model hosting nodes."; + + public static final String TROUBLE_QUERYING_ERR_MSG = "Having trouble querying data: "; + + public static final String NULL_RESPONSE = "Received null response from"; + + public static final String INDEX_READ_BLOCKED = "Cannot read user index due to read block."; + + public static final String READ_WRITE_BLOCKED = "Cannot read/write due to global block."; + + public static final String NODE_UNRESPONSIVE_ERR_MSG = "Model node is unresponsive. Mute node"; + + protected final TransportRequestOptions option; + private String entityResultAction; + protected Class transportResultResponseClazz; + private StatNames hcRequestCountStat; + private String threadPoolName; + // the fraction of an interval used to process requests. + // 1.0 means we use the entire detection interval to process requests. + // To ensure we don't block the next interval, it is better to set it to less than 1.0.
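+ // For example (illustrative): with a 10-minute interval and a ratio of 0.6, request + // processing should wrap up within roughly the first 6 minutes of the interval.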
+ private final float intervalRatioForRequest; + private int maxEntitiesPerInterval; + private int pageSize; + protected final ThreadPool threadPool; + private final HashRing hashRing; + protected final NodeStateManager nodeStateManager; + protected final TransportService transportService; + private final Stats timeSeriesStats; + private final TaskManagerType realTimeTaskManager; + private NamedXContentRegistry xContentRegistry; + private final Client client; + private final SecurityClientUtil clientUtil; + private Settings settings; + private final IndexNameExpressionResolver indexNameExpressionResolver; + private final ClusterService clusterService; + protected final FeatureManager featureManager; + protected final AnalysisType context; + + public ResultProcessor( + Setting requestTimeoutSetting, + float intervalRatioForRequests, + String entityResultAction, + StatNames hcRequestCountStat, + Settings settings, + ClusterService clusterService, + ThreadPool threadPool, + String threadPoolName, + HashRing hashRing, + NodeStateManager nodeStateManager, + TransportService transportService, + Stats timeSeriesStats, + TaskManagerType realTimeTaskManager, + NamedXContentRegistry xContentRegistry, + Client client, + SecurityClientUtil clientUtil, + IndexNameExpressionResolver indexNameExpressionResolver, + Class transportResultResponseClazz, + FeatureManager featureManager, + Setting maxEntitiesPerIntervalSetting, + Setting pageSizeSetting, + AnalysisType context + ) { + this.option = TransportRequestOptions + .builder() + .withType(TransportRequestOptions.Type.REG) + .withTimeout(requestTimeoutSetting.get(settings)) + .build(); + this.intervalRatioForRequest = intervalRatioForRequests; + + this.maxEntitiesPerInterval = maxEntitiesPerIntervalSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(maxEntitiesPerIntervalSetting, it -> maxEntitiesPerInterval = it); + + this.pageSize = pageSizeSetting.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(pageSizeSetting, it -> pageSize = it); + + this.entityResultAction = entityResultAction; + this.hcRequestCountStat = hcRequestCountStat; + this.threadPool = threadPool; + this.hashRing = hashRing; + this.nodeStateManager = nodeStateManager; + this.transportService = transportService; + this.timeSeriesStats = timeSeriesStats; + this.realTimeTaskManager = realTimeTaskManager; + this.xContentRegistry = xContentRegistry; + this.client = client; + this.clientUtil = clientUtil; + this.settings = settings; + this.indexNameExpressionResolver = indexNameExpressionResolver; + this.clusterService = clusterService; + this.transportResultResponseClazz = transportResultResponseClazz; + this.featureManager = featureManager; + this.context = context; + this.threadPoolName = threadPoolName; + } + + /** + * didn't use ActionListener.wrap so that I can + * 1) use this to refer to the listener inside the listener + * 2) pass parameters using constructors + * + */ + class PageListener implements ActionListener { + private PageIterator pageIterator; + private String configId; + private long dataStartTime; + private long dataEndTime; + + PageListener(PageIterator pageIterator, String detectorId, long dataStartTime, long dataEndTime) { + this.pageIterator = pageIterator; + this.configId = detectorId; + this.dataStartTime = dataStartTime; + this.dataEndTime = dataEndTime; + } + + @Override + public void onResponse(CompositeRetriever.Page entityFeatures) { + if (pageIterator.hasNext()) { + pageIterator.next(this); + } + 
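+ // The next page is requested above before this page is processed, so the follow-up query + // overlaps with dispatching the current page's entities below.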
if (entityFeatures != null && false == entityFeatures.isEmpty()) { + // wrap the expensive operation inside the AD thread pool + threadPool.executor(threadPoolName).execute(() -> { + try { + + Set>> node2Entities = entityFeatures + .getResults() + .entrySet() + .stream() + .filter(e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).isPresent()) + .collect( + Collectors + .groupingBy( + // from entity name to its node + e -> hashRing.getOwningNodeWithSameLocalVersionForRealtime(e.getKey().toString()).get(), + Collectors.toMap(Entry::getKey, Entry::getValue) + ) + ) + .entrySet(); + + Iterator>> iterator = node2Entities.iterator(); + + while (iterator.hasNext()) { + Entry> entry = iterator.next(); + DiscoveryNode modelNode = entry.getKey(); + if (modelNode == null) { + iterator.remove(); + continue; + } + String modelNodeId = modelNode.getId(); + if (nodeStateManager.isMuted(modelNodeId, configId)) { + LOG + .info( + String + .format( + Locale.ROOT, + ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for detector %s", + modelNodeId, + configId + ) + ); + iterator.remove(); + } + } + + final AtomicReference failure = new AtomicReference<>(); + node2Entities.stream().forEach(nodeEntity -> { + DiscoveryNode node = nodeEntity.getKey(); + transportService + .sendRequest( + node, + entityResultAction, + new EntityResultRequest(configId, nodeEntity.getValue(), dataStartTime, dataEndTime), + option, + new ActionListenerResponseHandler<>( + new ErrorResponseListener(node.getId(), configId, failure), + AcknowledgedResponse::new, + ThreadPool.Names.SAME + ) + ); + }); + + } catch (Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + }); + } + } + + @Override + public void onFailure(Exception e) { + LOG.error("Unexpected exception", e); + handleException(e); + } + + private void handleException(Exception e) { + Exception convertedException = convertedQueryFailureException(e, configId); + if (false == (convertedException instanceof TimeSeriesException)) { + Throwable cause = ExceptionsHelper.unwrapCause(convertedException); + convertedException = new InternalFailure(configId, cause); + } + nodeStateManager.setException(configId, convertedException); + } + } + + public ActionListener> onGetConfig( + ActionListener listener, + String configID, + TransportResultRequestType request, + Set hcDetectors + ) { + return ActionListener.wrap(configOptional -> { + if (!configOptional.isPresent()) { + listener.onFailure(new EndRunException(configID, "config is not available.", true)); + return; + } + + Config config = configOptional.get(); + if (config.isHighCardinality()) { + hcDetectors.add(configID); + timeSeriesStats.getStat(hcRequestCountStat.getName()).increment(); + } + + long delayMillis = Optional + .ofNullable((IntervalTimeConfiguration) config.getWindowDelay()) + .map(t -> t.toDuration().toMillis()) + .orElse(0L); + long dataStartTime = request.getStart() - delayMillis; + long dataEndTime = request.getEnd() - delayMillis; + + realTimeTaskManager + .initCacheWithCleanupIfRequired( + configID, + config, + transportService, + ActionListener + .runAfter( + initRealtimeTaskCacheListener(configID), + () -> executeAnalysis(listener, configID, request, config, dataStartTime, dataEndTime) + ) + ); + }, exception -> ResultProcessor.handleExecuteException(exception, listener, configID)); + } + + private ActionListener initRealtimeTaskCacheListener(String configId) { + return ActionListener.wrap(r -> { + if (r) { + LOG.debug("Realtime task cache initialized for config 
{}", configId); + } + }, e -> LOG.error("Failed to init realtime task cache for " + configId, e)); + } + + private void executeAnalysis( + ActionListener listener, + String configID, + ResultRequest request, + Config config, + long dataStartTime, + long dataEndTime + ) { + // HC logic starts here + if (config.isHighCardinality()) { + Optional previousException = nodeStateManager.fetchExceptionAndClear(configID); + if (previousException.isPresent()) { + Exception exception = previousException.get(); + LOG.error(new ParameterizedMessage("Previous exception of [{}]", configID), exception); + if (exception instanceof EndRunException) { + EndRunException endRunException = (EndRunException) exception; + if (endRunException.isEndNow()) { + listener.onFailure(exception); + return; + } + } + } + + // assume request timestamps are in epoch milliseconds + long nextDetectionStartTime = request.getEnd() + (long) (config.getIntervalInMilliseconds() * intervalRatioForRequest); + + CompositeRetriever compositeRetriever = new CompositeRetriever( + dataStartTime, + dataEndTime, + config, + xContentRegistry, + client, + clientUtil, + nextDetectionStartTime, + settings, + maxEntitiesPerInterval, + pageSize, + indexNameExpressionResolver, + clusterService, + context + ); + + PageIterator pageIterator = null; + + try { + pageIterator = compositeRetriever.iterator(); + } catch (Exception e) { + listener.onFailure(new EndRunException(config.getId(), CommonMessages.INVALID_SEARCH_QUERY_MSG, e, false)); + return; + } + + PageListener getEntityFeatureslistener = new PageListener(pageIterator, configID, dataStartTime, dataEndTime); + if (pageIterator.hasNext()) { + pageIterator.next(getEntityFeatureslistener); + } + + // We don't know when the pagination will finish. To avoid blocking the + // next interval's request from starting, we return immediately. + // Pagination will stop itself when the time is up. + if (previousException.isPresent()) { + listener.onFailure(previousException.get()); + } else { + listener + .onResponse( + ResultResponse + .create( + new ArrayList(), + null, + null, + config.getIntervalInMinutes(), + true, + transportResultResponseClazz + ) + ); + } + return; + } + + // HC logic ends and single entity logic starts here + // We are going to use only 1 model partition for a single stream detector. + // That's why we use 0 here.
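+ // A single-stream config therefore maps to exactly one RCF model; SingleStreamModelIdMapper + // derives a deterministic id from the config id plus the partition number, so every node + // resolves the same model id for the same config.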
+ String rcfModelID = SingleStreamModelIdMapper.getRcfModelId(configID, 0); + Optional asRCFNode = hashRing.getOwningNodeWithSameLocalVersionForRealtime(rcfModelID); + if (!asRCFNode.isPresent()) { + listener.onFailure(new InternalFailure(configID, "RCF model node is not available.")); + return; + } + + DiscoveryNode rcfNode = asRCFNode.get(); + + if (!shouldStart(listener, configID, config, rcfNode.getId(), rcfModelID)) { + return; + } + + featureManager + .getCurrentFeatures( + config, + dataStartTime, + dataEndTime, + onFeatureResponseForSingleStreamConfig(configID, config, listener, rcfModelID, rcfNode, dataStartTime, dataEndTime) + ); + } + + protected abstract ActionListener onFeatureResponseForSingleStreamConfig( + String configId, + Config config, + ActionListener listener, + String rcfModelId, + DiscoveryNode rcfNode, + long dataStartTime, + long dataEndTime + ); + + protected void handleQueryFailure(Exception exception, ActionListener listener, String adID) { + Exception convertedQueryFailureException = convertedQueryFailureException(exception, adID); + + if (convertedQueryFailureException instanceof EndRunException) { + // invalid feature query + listener.onFailure(convertedQueryFailureException); + } else { + ResultProcessor.handleExecuteException(convertedQueryFailureException, listener, adID); + } + } + + /** + * Convert a query-related exception to EndRunException + * + * These query exceptions can happen during the starting phase of the OpenSearch + * process. Thus, set the stopNow parameter of these EndRunExceptions to false + * and confirm the EndRunException is not a false positive. + * + * @param exception Exception + * @param adID detector Id + * @return the converted exception if the exception is query related + */ + private Exception convertedQueryFailureException(Exception exception, String adID) { + if (ExceptionUtil.isIndexNotAvailable(exception)) { + return new EndRunException(adID, ResultProcessor.TROUBLE_QUERYING_ERR_MSG + exception.getMessage(), false) + .countedInStats(false); + } else if (exception instanceof SearchPhaseExecutionException && invalidQuery((SearchPhaseExecutionException) exception)) { + // This is to catch an invalid aggregation on the wrong field type. For example, + // a sum aggregation on a text field. We should end the detector run in such a case. + return new EndRunException( + adID, + CommonMessages.INVALID_SEARCH_QUERY_MSG + " " + ((SearchPhaseExecutionException) exception).getDetailedMessage(), + exception, + false + ).countedInStats(false); + } + + return exception; + } + + protected void findException(Throwable cause, String configID, AtomicReference failure, String nodeId) { + if (cause == null) { + LOG.error(new ParameterizedMessage("Null input exception")); + return; + } + if (cause instanceof Error) { + // we cannot do anything with Error.
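+ // Errors such as OutOfMemoryError indicate JVM-level problems; we log and return without + // setting the failure reference since no recovery is possible here.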
+ LOG.error(new ParameterizedMessage("Error during prediction for {}: ", configID), cause); + return; + } + + Exception causeException = (Exception) cause; + + if (causeException instanceof TimeSeriesException) { + failure.set(causeException); + } else if (causeException instanceof NotSerializableExceptionWrapper) { + // we only expect this to happen for AD exceptions + Optional actualException = NotSerializedExceptionName + .convertWrappedTimeSeriesException((NotSerializableExceptionWrapper) causeException, configID); + if (actualException.isPresent()) { + TimeSeriesException adException = actualException.get(); + failure.set(adException); + if (adException instanceof ResourceNotFoundException) { + // During a rolling upgrade or blue/green deployment, ResourceNotFoundException might be caused by an old node + // using RCF 1.0 + // that cannot recognize a new checkpoint produced by a coordinating node using compact RCF. Add pressure to mute the node + // after consecutive failures. + nodeStateManager.addPressure(nodeId, configID); + } + } else { + // some unexpected bug occurred while predicting anomalies + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } else if (causeException instanceof OpenSearchTimeoutException) { + // we can have OpenSearchTimeoutException when a node tries to load RCF or + // threshold model + failure.set(new InternalFailure(configID, causeException)); + } else if (causeException instanceof IllegalArgumentException) { + // we can have IllegalArgumentException when a model is corrupted + failure.set(new InternalFailure(configID, causeException)); + } else { + // some unexpected bug occurred or the cluster is unstable (e.g., ClusterBlockException) or an index is red (e.g. + // NoShardAvailableActionException) while predicting anomalies + failure.set(new EndRunException(configID, CommonMessages.BUG_RESPONSE, causeException, false)); + } + } + + private boolean invalidQuery(SearchPhaseExecutionException ex) { + // If all shards return bad request and the failure cause is IllegalArgumentException, we + // consider the feature query invalid and will not count the error in failure stats. + for (ShardSearchFailure failure : ex.shardFailures()) { + if (RestStatus.BAD_REQUEST != failure.status() || !(failure.getCause() instanceof IllegalArgumentException)) { + return false; + } + } + return true; + } + + /** + * Handle a prediction failure. When needed (i.e., we don't always have to), + * convert the exception to a form that AD can recognize and handle, and set the + * input failure reference to the converted exception. + * + * @param e prediction exception + * @param adID Detector Id + * @param nodeID Node Id + * @param failure Parameter to receive the possibly converted exception for the + * caller to deal with + */ + protected void handlePredictionFailure(Exception e, String adID, String nodeID, AtomicReference failure) { + LOG.error(new ParameterizedMessage("Received an error from node {} while doing model inference for {}", nodeID, adID), e); + if (e == null) { + return; + } + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (hasConnectionIssue(cause)) { + handleConnectionException(nodeID, adID); + } else { + findException(cause, adID, failure, nodeID); + } + } + + /** + * Check if the input exception indicates connection issues. + * During blue-green deployment, we may see ActionNotFoundTransportException. + * Count that as a connection issue and isolate that node if it continues to happen.
+ *
+ * @param e exception
+ * @return true if we get disconnected from the node, the node is not in the
+ * right state (being closed), or the transport request times out (sent from TimeoutHandler.run)
+ */
+ private boolean hasConnectionIssue(Throwable e) {
+ return e instanceof ConnectTransportException
+ || e instanceof NodeClosedException
+ || e instanceof ReceiveTimeoutTransportException
+ || e instanceof NodeNotConnectedException
+ || e instanceof ConnectException
+ || NetworkExceptionHelper.isCloseConnectionException(e)
+ || e instanceof ActionNotFoundTransportException;
+ }
+
+ private void handleConnectionException(String node, String detectorId) {
+ final DiscoveryNodes nodes = clusterService.state().nodes();
+ if (!nodes.nodeExists(node)) {
+ hashRing.buildCirclesForRealtime();
+ return;
+ }
+ // rebuilding is not done or the node is unresponsive
+ nodeStateManager.addPressure(node, detectorId);
+ }
+
+ /**
+ * Since we need to read from the customer's index and write to the result index,
+ * we need to make sure we can both read and write.
+ *
+ * @param state Cluster state
+ * @return whether we have a global block or not
+ */
+ private boolean checkGlobalBlock(ClusterState state) {
+ return state.blocks().globalBlockedException(ClusterBlockLevel.READ) != null
+ || state.blocks().globalBlockedException(ClusterBlockLevel.WRITE) != null;
+ }
+
+ /**
+ * Similar to checkGlobalBlock, we check blocks at the indices level.
+ *
+ * @param state Cluster state
+ * @param level block level
+ * @param indices the indices on which to check blocks
+ * @return whether any of the indices has a block at the given level.
+ */
+ private boolean checkIndicesBlocked(ClusterState state, ClusterBlockLevel level, String... indices) {
+ // the original index might be an index expression with wildcards like "log*",
+ // so we need to expand the expression to concrete index names
+ String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(state, IndicesOptions.lenientExpandOpen(), indices);
+
+ return state.blocks().indicesBlockedException(level, concreteIndices) != null;
+ }
+
+ /**
+ * Check if we should start anomaly prediction.
+ *
+ * @param listener listener to respond back to AnomalyResultRequest.
+ * @param adID detector ID
+ * @param detector detector instance corresponding to adID
+ * @param rcfNodeId the RCF model hosting node ID for adID
+ * @param rcfModelID the RCF model ID for adID
+ * @return whether we can start anomaly prediction.
+ */
+ private boolean shouldStart(
+ ActionListener listener,
+ String adID,
+ Config detector,
+ String rcfNodeId,
+ String rcfModelID
+ ) {
+ ClusterState state = clusterService.state();
+ if (checkGlobalBlock(state)) {
+ listener.onFailure(new InternalFailure(adID, ResultProcessor.READ_WRITE_BLOCKED));
+ return false;
+ }
+
+ if (nodeStateManager.isMuted(rcfNodeId, adID)) {
+ listener
+ .onFailure(
+ new InternalFailure(
+ adID,
+ String
+ .format(Locale.ROOT, ResultProcessor.NODE_UNRESPONSIVE_ERR_MSG + " %s for rcf model %s", rcfNodeId, rcfModelID)
+ )
+ );
+ return false;
+ }
+
+ if (checkIndicesBlocked(state, ClusterBlockLevel.READ, detector.getIndices().toArray(new String[0]))) {
+ listener.onFailure(new InternalFailure(adID, ResultProcessor.INDEX_READ_BLOCKED));
+ return false;
+ }
+
+ return true;
+ }
+
+ public static void handleExecuteException(Exception ex, ActionListener listener, String id) {
+ if (ex instanceof ClientException) {
+ listener.onFailure(ex);
+ } else if (ex instanceof TimeSeriesException) {
+ listener.onFailure(new InternalFailure((TimeSeriesException) ex));
+ } else {
+ Throwable cause = ExceptionsHelper.unwrapCause(ex);
+ listener.onFailure(new InternalFailure(id, cause));
+ }
+ }
+
+ public class ErrorResponseListener implements ActionListener<AcknowledgedResponse> {
+ private String nodeId;
+ private final String configId;
+ private AtomicReference<Exception> failure;
+
+ public ErrorResponseListener(String nodeId, String configId, AtomicReference<Exception> failure) {
+ this.nodeId = nodeId;
+ this.configId = configId;
+ this.failure = failure;
+ }
+
+ @Override
+ public void onResponse(AcknowledgedResponse response) {
+ try {
+ if (response.isAcknowledged() == false) {
+ LOG.error("Cannot send entities' features to {} for {}", nodeId, configId);
+ nodeStateManager.addPressure(nodeId, configId);
+ } else {
+ nodeStateManager.resetBackpressureCounter(nodeId, configId);
+ }
+ } catch (Exception ex) {
+ LOG.error("Unexpected exception: {} for {}", ex, configId);
+ handleException(ex);
+ }
+ }
+
+ @Override
+ public void onFailure(Exception e) {
+ try {
+ // e.g., we have connection issues with all of the nodes while restarting clusters
+ LOG.error(new ParameterizedMessage("Cannot send entities' features to {} for {}", nodeId, configId), e);
+
+ handleException(e);
+
+ } catch (Exception ex) {
+ LOG.error("Unexpected exception: {} for {}", ex, configId);
+ handleException(ex);
+ }
+ }
+
+ private void handleException(Exception e) {
+ handlePredictionFailure(e, configId, nodeId, failure);
+ if (failure.get() != null) {
+ nodeStateManager.setException(configId, failure.get());
+ }
+ }
+ }
+}
diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java
new file mode 100644
index 000000000..5227703e4
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/transport/ResultRequest.java
@@ -0,0 +1,60 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.timeseries.transport;
+
+import java.io.IOException;
+
+import org.opensearch.action.ActionRequest;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.common.io.stream.StreamOutput;
+import org.opensearch.core.xcontent.ToXContentObject;
+
+public abstract class ResultRequest extends ActionRequest implements ToXContentObject {
+ protected String configId;
+ // time range start and end. Unit: epoch milliseconds
+ protected long start;
+ protected long end;
+
+ public ResultRequest(StreamInput in) throws IOException {
+ super(in);
+ configId = in.readString();
+ start = in.readLong();
+ end = in.readLong();
+ }
+
+ public ResultRequest(String configID, long start, long end) {
+ super();
+ this.configId = configID;
+ this.start = start;
+ this.end = end;
+ }
+
+ public long getStart() {
+ return start;
+ }
+
+ public long getEnd() {
+ return end;
+ }
+
+ public String getConfigId() {
+ return configId;
+ }
+
+ @Override
+ public void writeTo(StreamOutput out) throws IOException {
+ super.writeTo(out);
+ out.writeString(configId);
+ out.writeLong(start);
+ out.writeLong(end);
+ }
+}
diff --git a/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java
new file mode 100644
index 000000000..604089e3f
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/transport/ResultResponse.java
@@ -0,0 +1,118 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.timeseries.transport;
+
+import java.io.IOException;
+import java.time.Instant;
+import java.util.List;
+
+import org.opensearch.action.ActionResponse;
+import org.opensearch.ad.transport.AnomalyResultResponse;
+import org.opensearch.commons.authuser.User;
+import org.opensearch.core.common.io.stream.StreamInput;
+import org.opensearch.core.xcontent.ToXContentObject;
+import org.opensearch.forecast.transport.ForecastResultResponse;
+import org.opensearch.timeseries.model.FeatureData;
+import org.opensearch.timeseries.model.IndexableResult;
+
+public abstract class ResultResponse<IndexableResultType extends IndexableResult> extends ActionResponse implements ToXContentObject {
+
+ protected String error;
+ protected List<FeatureData> features;
+ protected Long rcfTotalUpdates;
+ protected Long configIntervalInMinutes;
+ protected Boolean isHC;
+
+ public ResultResponse(List<FeatureData> features, String error, Long rcfTotalUpdates, Long configInterval, Boolean isHC) {
+ this.error = error;
+ this.features = features;
+ this.rcfTotalUpdates = rcfTotalUpdates;
+ this.configIntervalInMinutes = configInterval;
+ this.isHC = isHC;
+ }
+
+ /**
+ * Create an empty result response, or one representing an erroneous state.
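+ * The concrete subclass is chosen via the {@code clazz} token, for example
+ * {@code ResultResponse.create(features, error, rcfTotalUpdates, configInterval, false, AnomalyResultResponse.class)},
+ * so shared callers do not have to branch on AD vs. forecasting.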
+ * @param <T> concrete result response type
+ * @param error error message; null if there is no error
+ * @param features features extracted for the current interval
+ * @param rcfTotalUpdates total updates of the RCF model
+ * @param configInterval config interval in minutes
+ * @param isHC whether the config is high cardinality
+ * @param clazz class token of the concrete response type
+ * @return an empty result response of the requested type
+ */
+ public static <T extends ResultResponse> T create(
+ List<FeatureData> features,
+ String error,
+ Long rcfTotalUpdates,
+ Long configInterval,
+ Boolean isHC,
+ Class<T> clazz
+ ) {
+ if (clazz.isAssignableFrom(AnomalyResultResponse.class)) {
+ return clazz.cast(new AnomalyResultResponse(features, error, rcfTotalUpdates, configInterval, isHC));
+ } else if (clazz.isAssignableFrom(ForecastResultResponse.class)) {
+ return clazz.cast(new ForecastResultResponse(features, error, rcfTotalUpdates, configInterval, isHC));
+ } else {
+ throw new IllegalArgumentException("Unsupported result response type");
+ }
+ }
+
+ /**
+ * How to deserialize a concrete result response is left as an implementation detail of the subclass.
+ * @param in deserialization stream
+ * @throws IOException when deserialization errs
+ */
+ public ResultResponse(StreamInput in) throws IOException {
+ super(in);
+ }
+
+ public String getError() {
+ return error;
+ }
+
+ public List<FeatureData> getFeatures() {
+ return features;
+ }
+
+ public Long getRcfTotalUpdates() {
+ return rcfTotalUpdates;
+ }
+
+ public Long getConfigIntervalInMinutes() {
+ return configIntervalInMinutes;
+ }
+
+ public Boolean isHC() {
+ return isHC;
+ }
+
+ /**
+ *
+ * @return whether we should save the response to the result index
+ */
+ public boolean shouldSave() {
+ return error != null;
+ }
+
+ public abstract List<IndexableResultType> toIndexableResults(
+ String configId,
+ Instant dataStartInstant,
+ Instant dataEndInstant,
+ Instant executionStartInstant,
+ Instant executionEndInstant,
+ Integer schemaVersion,
+ User user,
+ String error
+ );
+}
diff --git a/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java
new file mode 100644
index 000000000..c5a7d32c5
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/transport/SingleStreamResultRequest.java
@@ -0,0 +1,115 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */ + +package org.opensearch.timeseries.transport; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Locale; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.Strings; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; + +public class SingleStreamResultRequest extends ActionRequest implements ToXContentObject { + private final String configId; + private final String modelId; + + // data start/end time epoch in milliseconds + private final long startMillis; + private final long endMillis; + private final double[] datapoint; + + public SingleStreamResultRequest(String configId, String modelId, long start, long end, double[] datapoint) { + super(); + this.configId = configId; + this.modelId = modelId; + this.startMillis = start; + this.endMillis = end; + this.datapoint = datapoint; + } + + public SingleStreamResultRequest(StreamInput in) throws IOException { + super(in); + this.configId = in.readString(); + this.modelId = in.readString(); + this.startMillis = in.readLong(); + this.endMillis = in.readLong(); + this.datapoint = in.readDoubleArray(); + } + + public String getConfigId() { + return this.configId; + } + + public String getModelId() { + return modelId; + } + + public long getStart() { + return this.startMillis; + } + + public long getEnd() { + return this.endMillis; + } + + public double[] getDataPoint() { + return this.datapoint; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.configId); + out.writeString(this.modelId); + out.writeLong(this.startMillis); + out.writeLong(this.endMillis); + out.writeDoubleArray(datapoint); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CommonName.CONFIG_ID_KEY, configId); + builder.field(CommonName.MODEL_ID_KEY, modelId); + builder.field(CommonName.START_JSON_KEY, startMillis); + builder.field(CommonName.END_JSON_KEY, endMillis); + builder.array(CommonName.VALUE_LIST_FIELD, datapoint); + builder.endObject(); + return builder; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (Strings.isEmpty(configId)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); + } + if (Strings.isEmpty(modelId)) { + validationException = addValidationError(CommonMessages.MODEL_ID_MISSING_MSG, validationException); + } + if (startMillis <= 0 || endMillis <= 0 || startMillis > endMillis) { + validationException = addValidationError( + String.format(Locale.ROOT, "%s: start %d, end %d", CommonMessages.INVALID_TIMESTAMP_ERR_MSG, startMillis, endMillis), + validationException + ); + } + return validationException; + } +} diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java similarity index 64% rename from src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java rename to 
src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java index 71563a2cd..da70786a3 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorRequest.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigRequest.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import static org.opensearch.action.ValidateActions.addValidationError; @@ -19,8 +19,6 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequestValidationException; -import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.constant.ADCommonName; import org.opensearch.core.common.Strings; import org.opensearch.core.common.io.stream.InputStreamStreamInput; import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; @@ -28,43 +26,45 @@ import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; -public class StopDetectorRequest extends ActionRequest implements ToXContentObject { +public class StopConfigRequest extends ActionRequest implements ToXContentObject { - private String adID; + private String configID; - public StopDetectorRequest() {} + public StopConfigRequest() {} - public StopDetectorRequest(StreamInput in) throws IOException { + public StopConfigRequest(StreamInput in) throws IOException { super(in); - this.adID = in.readString(); + this.configID = in.readString(); } - public StopDetectorRequest(String adID) { + public StopConfigRequest(String configID) { super(); - this.adID = adID; + this.configID = configID; } - public String getAdID() { - return adID; + public String getConfigID() { + return configID; } - public StopDetectorRequest adID(String adID) { - this.adID = adID; + public StopConfigRequest adID(String configID) { + this.configID = configID; return this; } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(adID); + out.writeString(configID); } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isEmpty(adID)) { - validationException = addValidationError(ADCommonMessages.AD_ID_MISSING_MSG, validationException); + if (Strings.isEmpty(configID)) { + validationException = addValidationError(CommonMessages.CONFIG_ID_MISSING_MSG, validationException); } return validationException; } @@ -72,20 +72,20 @@ public ActionRequestValidationException validate() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ADCommonName.ID_JSON_KEY, adID); + builder.field(CommonName.CONFIG_ID_KEY, configID); builder.endObject(); return builder; } - public static StopDetectorRequest fromActionRequest(final ActionRequest actionRequest) { - if (actionRequest instanceof StopDetectorRequest) { - return (StopDetectorRequest) actionRequest; + public static StopConfigRequest fromActionRequest(final ActionRequest actionRequest) { + if (actionRequest instanceof StopConfigRequest) { + return (StopConfigRequest) actionRequest; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionRequest.writeTo(osso); try (StreamInput input = 
new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorRequest(input); + return new StopConfigRequest(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionRequest into StopDetectorRequest", e); diff --git a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java similarity index 78% rename from src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java rename to src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java index b3606b918..39e0b503d 100644 --- a/src/main/java/org/opensearch/ad/transport/StopDetectorResponse.java +++ b/src/main/java/org/opensearch/timeseries/transport/StopConfigResponse.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.transport; +package org.opensearch.timeseries.transport; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -23,15 +23,15 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -public class StopDetectorResponse extends ActionResponse implements ToXContentObject { +public class StopConfigResponse extends ActionResponse implements ToXContentObject { public static final String SUCCESS_JSON_KEY = "success"; private boolean success; - public StopDetectorResponse(boolean success) { + public StopConfigResponse(boolean success) { this.success = success; } - public StopDetectorResponse(StreamInput in) throws IOException { + public StopConfigResponse(StreamInput in) throws IOException { super(in); success = in.readBoolean(); } @@ -53,15 +53,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public static StopDetectorResponse fromActionResponse(final ActionResponse actionResponse) { - if (actionResponse instanceof StopDetectorResponse) { - return (StopDetectorResponse) actionResponse; + public static StopConfigResponse fromActionResponse(final ActionResponse actionResponse) { + if (actionResponse instanceof StopConfigResponse) { + return (StopConfigResponse) actionResponse; } try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { actionResponse.writeTo(osso); try (InputStreamStreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { - return new StopDetectorResponse(input); + return new StopConfigResponse(input); } } catch (IOException e) { throw new IllegalArgumentException("failed to parse ActionResponse into StopDetectorResponse", e); diff --git a/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java new file mode 100644 index 000000000..70de6e683 --- /dev/null +++ b/src/main/java/org/opensearch/timeseries/transport/handler/IndexMemoryPressureAwareResultHandler.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */
+
+package org.opensearch.timeseries.transport.handler;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.ExceptionsHelper;
+import org.opensearch.ResourceAlreadyExistsException;
+import org.opensearch.action.ActionListener;
+import org.opensearch.client.Client;
+import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.indices.IndexManagement;
+import org.opensearch.timeseries.indices.TimeSeriesIndex;
+
+/**
+ * Different from ResultIndexingHandler and ResultBulkIndexingHandler, this class uses
+ * a customized transport action to bulk index results. These transport actions
+ * reduce traffic when index memory pressure is high.
+ *
+ *
+ * @param <BatchRequestType> Batch request type
+ * @param <BatchResponseType> Batch response type
+ * @param <IndexType> forecasting or AD result index
+ * @param <IndexManagementType> Index management class
+ */
+public abstract class IndexMemoryPressureAwareResultHandler<BatchRequestType, BatchResponseType, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>> {
+ private static final Logger LOG = LogManager.getLogger(IndexMemoryPressureAwareResultHandler.class);
+
+ protected final Client client;
+ protected final IndexManagementType timeSeriesIndices;
+
+ public IndexMemoryPressureAwareResultHandler(Client client, IndexManagementType timeSeriesIndices) {
+ this.client = client;
+ this.timeSeriesIndices = timeSeriesIndices;
+ }
+
+ /**
+ * Execute the bulk request.
+ * @param currentBulkRequest The bulk request
+ * @param listener callback after flushing
+ */
+ public void flush(BatchRequestType currentBulkRequest, ActionListener<BatchResponseType> listener) {
+ try {
+ // Only create the custom result index when creating the detector; won't recreate the custom result index in realtime
+ // job and historical analysis later if it's deleted. If the user deletes the custom result index and the plugin
+ // recreates it, that may cause confusion.
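+ // Flow: ensure the default result index exists (creating it on first use), then
+ // delegate to the subclass bulk() implementation; a concurrent index creation by
+ // another node is treated as success below.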
+ if (!timeSeriesIndices.doesDefaultResultIndexExist()) {
+ timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(initResponse -> {
+ if (initResponse.isAcknowledged()) {
+ bulk(currentBulkRequest, listener);
+ } else {
+ LOG.warn("Creating result index with mappings call not acknowledged.");
+ listener.onFailure(new TimeSeriesException("", "Creating result index with mappings call not acknowledged."));
+ }
+ }, exception -> {
+ if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
+ // It is possible the index has been created while we were sending the create request
+ bulk(currentBulkRequest, listener);
+ } else {
+ LOG.warn("Unexpected error creating result index", exception);
+ listener.onFailure(exception);
+ }
+ }));
+ } else {
+ bulk(currentBulkRequest, listener);
+ }
+ } catch (Exception e) {
+ LOG.warn("Error in bulking results", e);
+ listener.onFailure(e);
+ }
+ }
+
+ public abstract void bulk(BatchRequestType currentBulkRequest, ActionListener<BatchResponseType> listener);
+}
diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
similarity index 54%
rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java
rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
index c021ead73..417b73dbc 100644
--- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandler.java
+++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultBulkIndexingHandler.java
@@ -9,11 +9,9 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.transport.handler;
+package org.opensearch.timeseries.transport.handler;
-import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS;
 import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
-import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX;
 import java.util.List;
@@ -25,98 +23,121 @@
 import org.opensearch.action.bulk.BulkRequestBuilder;
 import org.opensearch.action.bulk.BulkResponse;
 import org.opensearch.action.index.IndexRequest;
-import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.model.AnomalyResult;
-import org.opensearch.ad.util.ClientUtil;
-import org.opensearch.ad.util.IndexUtils;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
+import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.indices.IndexManagement;
+import org.opensearch.timeseries.indices.TimeSeriesIndex;
+import org.opensearch.timeseries.model.IndexableResult;
+import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 import org.opensearch.timeseries.util.RestHandlerUtils;
-public class AnomalyResultBulkIndexHandler extends AnomalyIndexHandler<AnomalyResult> {
- private static final Logger LOG = LogManager.getLogger(AnomalyResultBulkIndexHandler.class);
+/**
+ *
+ * Utility class to bulk index results
+ *
+ */
+public class ResultBulkIndexingHandler<ResultType extends IndexableResult, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>>
+ extends ResultIndexingHandler<ResultType, IndexType, IndexManagementType> {
- private ADIndexManagement anomalyDetectionIndices;
+ private static final Logger LOG = LogManager.getLogger(ResultBulkIndexingHandler.class);
- public AnomalyResultBulkIndexHandler(
+ public ResultBulkIndexingHandler(
 Client client,
 Settings settings,
 ThreadPool threadPool,
+ String indexName,
+ IndexManagementType timeSeriesIndices,
 ClientUtil clientUtil,
 IndexUtils indexUtils,
 ClusterService clusterService,
- ADIndexManagement anomalyDetectionIndices
+ Setting<TimeValue> backOffDelaySetting,
+ Setting<Integer> maxRetrySetting
 ) {
- super(client, settings, threadPool, ANOMALY_RESULT_INDEX_ALIAS, anomalyDetectionIndices, clientUtil, indexUtils, clusterService);
- this.anomalyDetectionIndices = anomalyDetectionIndices;
+ super(
+ client,
+ settings,
+ threadPool,
+ indexName,
+ timeSeriesIndices,
+ clientUtil,
+ indexUtils,
+ clusterService,
+ backOffDelaySetting,
+ maxRetrySetting
+ );
 }
 /**
- * Bulk index anomaly results. Create anomaly result index first if it doesn't exist.
+ * Bulk index results. Create the result index first if it doesn't exist.
 *
- * @param resultIndex anomaly result index
- * @param anomalyResults anomaly results
+ * @param resultIndex result index
+ * @param results results to save
+ * @param configId Config Id
 * @param listener action listener
 */
- public void bulkIndexAnomalyResult(String resultIndex, List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
- if (anomalyResults == null || anomalyResults.size() == 0) {
+ public void bulk(String resultIndex, List<ResultType> results, String configId, ActionListener<BulkResponse> listener) {
+ if (results == null || results.size() == 0) {
 listener.onResponse(null);
 return;
 }
- String detectorId = anomalyResults.get(0).getConfigId();
+
 try {
 if (resultIndex != null) {
- // Only create custom AD result index when create detector, won’t recreate custom AD result index in realtime
+ // Only create the custom result index when creating the detector; won't recreate the custom result index in realtime
 // job and historical analysis later if it's deleted. If the user deletes the custom result index and the plugin
 // recreates it, that may cause confusion.
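 // A missing or wrongly-mapped custom result index is a user-fixable, fatal condition,
 // so the checks below end the run (EndRunException with endNow set to true) rather than retry.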
- if (!anomalyDetectionIndices.doesIndexExist(resultIndex)) {
- throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + resultIndex, true);
+ if (!timeSeriesIndices.doesIndexExist(resultIndex)) {
+ throw new EndRunException(configId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + resultIndex, true);
 }
- if (!anomalyDetectionIndices.isValidResultIndexMapping(resultIndex)) {
- throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true);
+ if (!timeSeriesIndices.isValidResultIndexMapping(resultIndex)) {
+ throw new EndRunException(configId, "wrong index mapping of custom result index", true);
 }
- bulkSaveDetectorResult(resultIndex, anomalyResults, listener);
+ bulk(resultIndex, results, listener);
 return;
 }
- if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) {
- anomalyDetectionIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> {
+ if (!timeSeriesIndices.doesDefaultResultIndexExist()) {
+ timeSeriesIndices.initDefaultResultIndexDirectly(ActionListener.wrap(response -> {
 if (response.isAcknowledged()) {
- bulkSaveDetectorResult(anomalyResults, listener);
+ bulk(results, listener);
 } else {
- String error = "Creating anomaly result index with mappings call not acknowledged";
+ String error = "Creating result index with mappings call not acknowledged";
 LOG.error(error);
 listener.onFailure(new TimeSeriesException(error));
 }
 }, exception -> {
 if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
 // It is possible the index has been created while we were sending the create request
- bulkSaveDetectorResult(anomalyResults, listener);
+ bulk(results, listener);
 } else {
 listener.onFailure(exception);
 }
 }));
 } else {
- bulkSaveDetectorResult(anomalyResults, listener);
+ bulk(results, listener);
 }
 } catch (TimeSeriesException e) {
 listener.onFailure(e);
 } catch (Exception e) {
- String error = "Failed to bulk index anomaly result";
+ String error = "Failed to bulk index result";
 LOG.error(error, e);
 listener.onFailure(new TimeSeriesException(error, e));
 }
 }
- private void bulkSaveDetectorResult(List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
- bulkSaveDetectorResult(ANOMALY_RESULT_INDEX_ALIAS, anomalyResults, listener);
+ private void bulk(List<ResultType> anomalyResults, ActionListener<BulkResponse> listener) {
+ bulk(defaultResultIndexName, anomalyResults, listener);
 }
- private void bulkSaveDetectorResult(String resultIndex, List<AnomalyResult> anomalyResults, ActionListener<BulkResponse> listener) {
+ private void bulk(String resultIndex, List<ResultType> anomalyResults, ActionListener<BulkResponse> listener) {
 BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
 anomalyResults.forEach(anomalyResult -> {
 try (XContentBuilder builder = jsonBuilder()) {
@@ -124,7 +145,7 @@ private void bulkSaveDetectorResult(String resultIndex, List anom
 .source(anomalyResult.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE));
 bulkRequestBuilder.add(indexRequest);
 } catch (Exception e) {
- String error = "Failed to prepare request to bulk index anomaly results";
+ String error = "Failed to prepare request to bulk index results";
 LOG.error(error, e);
 throw new TimeSeriesException(error);
 }
@@ -132,16 +153,15 @@ private void bulkSaveDetectorResult(String resultIndex, List anom
 client.bulk(bulkRequestBuilder.request(), ActionListener.wrap(r -> {
 if (r.hasFailures()) {
 String failureMessage = r.buildFailureMessage();
- LOG.warn("Failed to bulk index AD result " + failureMessage);
+ LOG.warn("Failed to bulk index result " + failureMessage);
 listener.onFailure(new
TimeSeriesException(failureMessage));
 } else {
 listener.onResponse(r);
 }
 }, e -> {
- LOG.error("bulk index ad result failed", e);
+ LOG.error("bulk index result failed", e);
 listener.onFailure(e);
 }));
 }
-
 }
diff --git a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
similarity index 77%
rename from src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java
rename to src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
index 371640ad2..49dec72d3 100644
--- a/src/main/java/org/opensearch/ad/transport/handler/AnomalyIndexHandler.java
+++ b/src/main/java/org/opensearch/timeseries/transport/handler/ResultIndexingHandler.java
@@ -9,10 +9,9 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.transport.handler;
+package org.opensearch.timeseries.transport.handler;
 import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder;
-import static org.opensearch.timeseries.constant.CommonMessages.CAN_NOT_FIND_RESULT_INDEX;
 import java.util.Iterator;
 import java.util.Locale;
@@ -26,26 +25,28 @@
 import org.opensearch.action.bulk.BackoffPolicy;
 import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
-import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.util.BulkUtil;
-import org.opensearch.ad.util.ClientUtil;
-import org.opensearch.ad.util.IndexUtils;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.block.ClusterBlockLevel;
 import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Setting;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
-import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.indices.IndexManagement;
+import org.opensearch.timeseries.indices.TimeSeriesIndex;
+import org.opensearch.timeseries.model.IndexableResult;
+import org.opensearch.timeseries.util.BulkUtil;
+import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 import org.opensearch.timeseries.util.RestHandlerUtils;
-public class AnomalyIndexHandler<T extends ToXContentObject> {
- private static final Logger LOG = LogManager.getLogger(AnomalyIndexHandler.class);
+public class ResultIndexingHandler<ResultType extends IndexableResult, IndexType extends Enum<IndexType> & TimeSeriesIndex, IndexManagementType extends IndexManagement<IndexType>> {
+ private static final Logger LOG = LogManager.getLogger(ResultIndexingHandler.class);
 static final String FAIL_TO_SAVE_ERR_MSG = "Fail to save %s: ";
 static final String SUCCESS_SAVING_MSG = "Succeed in saving %s";
 static final String CANNOT_SAVE_ERR_MSG = "Cannot save %s due to write block.";
@@ -55,8 +56,8 @@ public class AnomalyIndexHandler {
 protected final ThreadPool threadPool;
 protected final BackoffPolicy savingBackoffPolicy;
- protected final String indexName;
- protected final ADIndexManagement anomalyDetectionIndices;
+ protected final String defaultResultIndexName;
+ protected final IndexManagementType timeSeriesIndices;
 // whether to save to a specific doc id or not. False by default.
 protected boolean fixedDoc;
 protected final ClientUtil clientUtil;
@@ -70,30 +71,28 @@
 * @param settings accessor for node settings.
 * @param threadPool used to invoke specific threadpool to execute
 * @param indexName name of index to save to
- * @param anomalyDetectionIndices anomaly detection indices
+ * @param timeSeriesIndices time series indices
 * @param clientUtil client wrapper
 * @param indexUtils Index util classes
 * @param clusterService accessor to ES cluster service
 */
- public AnomalyIndexHandler(
+ public ResultIndexingHandler(
 Client client,
 Settings settings,
 ThreadPool threadPool,
 String indexName,
- ADIndexManagement anomalyDetectionIndices,
+ IndexManagementType timeSeriesIndices,
 ClientUtil clientUtil,
 IndexUtils indexUtils,
- ClusterService clusterService
+ ClusterService clusterService,
+ Setting<TimeValue> backOffDelaySetting,
+ Setting<Integer> maxRetrySetting
 ) {
 this.client = client;
 this.threadPool = threadPool;
- this.savingBackoffPolicy = BackoffPolicy
- .exponentialBackoff(
- AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings),
- AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings)
- );
- this.indexName = indexName;
- this.anomalyDetectionIndices = anomalyDetectionIndices;
+ this.savingBackoffPolicy = BackoffPolicy.exponentialBackoff(backOffDelaySetting.get(settings), maxRetrySetting.get(settings));
+ this.defaultResultIndexName = indexName;
+ this.timeSeriesIndices = timeSeriesIndices;
 this.fixedDoc = false;
 this.clientUtil = clientUtil;
 this.indexUtils = indexUtils;
@@ -111,8 +110,8 @@
 public void setFixedDoc(boolean fixedDoc) {
 }
 // TODO: check if user has permission to index.
- public void index(T toSave, String detectorId, String customIndexName) {
- if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.indexName)) {
+ public void index(ResultType toSave, String detectorId, String customIndexName) {
+ if (indexUtils.checkIndicesBlocked(clusterService.state(), ClusterBlockLevel.WRITE, this.defaultResultIndexName)) {
 LOG.warn(String.format(Locale.ROOT, CANNOT_SAVE_ERR_MSG, detectorId));
 return;
 }
@@ -122,17 +121,17 @@ public void index(T toSave, String detectorId, String customIndexName) {
 // Only create the custom AD result index when creating the detector; won't recreate the custom AD result index in realtime
 // job and historical analysis later if it's deleted. If the user deletes the custom AD result index and the AD plugin
 // recreates it, that may cause confusion.
- if (!anomalyDetectionIndices.doesIndexExist(customIndexName)) {
- throw new EndRunException(detectorId, CAN_NOT_FIND_RESULT_INDEX + customIndexName, true);
+ if (!timeSeriesIndices.doesIndexExist(customIndexName)) {
+ throw new EndRunException(detectorId, CommonMessages.CAN_NOT_FIND_RESULT_INDEX + customIndexName, true);
 }
- if (!anomalyDetectionIndices.isValidResultIndexMapping(customIndexName)) {
+ if (!timeSeriesIndices.isValidResultIndexMapping(customIndexName)) {
 throw new EndRunException(detectorId, "wrong index mapping of custom AD result index", true);
 }
 save(toSave, detectorId, customIndexName);
 return;
 }
- if (!anomalyDetectionIndices.doesDefaultResultIndexExist()) {
- anomalyDetectionIndices
+ if (!timeSeriesIndices.doesDefaultResultIndexExist()) {
+ timeSeriesIndices
 .initDefaultResultIndexDirectly(
 ActionListener.wrap(initResponse -> onCreateIndexResponse(initResponse, toSave, detectorId), exception -> {
 if (ExceptionsHelper.unwrapCause(exception) instanceof ResourceAlreadyExistsException) {
@@ -141,7 +140,7 @@ public void index(T toSave, String detectorId, String customIndexName) {
 } else {
 throw new TimeSeriesException(
 detectorId,
- String.format(Locale.ROOT, "Unexpected error creating index %s", indexName),
+ String.format(Locale.ROOT, "Unexpected error creating index %s", defaultResultIndexName),
 exception
 );
 }
@@ -153,32 +152,32 @@
 } catch (Exception e) {
 throw new TimeSeriesException(
 detectorId,
- String.format(Locale.ROOT, "Error in saving %s for detector %s", indexName, detectorId),
+ String.format(Locale.ROOT, "Error in saving %s for detector %s", defaultResultIndexName, detectorId),
 e
 );
 }
 }
- private void onCreateIndexResponse(CreateIndexResponse response, T toSave, String detectorId) {
+ private void onCreateIndexResponse(CreateIndexResponse response, ResultType toSave, String detectorId) {
 if (response.isAcknowledged()) {
 save(toSave, detectorId);
 } else {
 throw new TimeSeriesException(
 detectorId,
- String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", indexName)
+ String.format(Locale.ROOT, "Creating %s with mappings call not acknowledged.", defaultResultIndexName)
 );
 }
 }
- protected void save(T toSave, String detectorId) {
- save(toSave, detectorId, indexName);
+ protected void save(ResultType toSave, String detectorId) {
+ save(toSave, detectorId, defaultResultIndexName);
 }
 // TODO: Upgrade custom result index mapping to latest version?
 // It may bring some issues if we upgrade the custom result index mapping while the user is using that index
 // for other use cases. One easy solution is to tell users to only use the custom result index for the AD plugin.
 // For the first release of the custom result index, it's not an issue. Will leave this to the next phase.
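 // Note: the save(...) methods below retry writes rejected due to write-queue pressure,
 // using the exponential backoff policy configured in the constructor (see saveIteration).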
- protected void save(T toSave, String detectorId, String indexName) {
+ protected void save(ResultType toSave, String detectorId, String indexName) {
 try (XContentBuilder builder = jsonBuilder()) {
 IndexRequest indexRequest = new IndexRequest(indexName).source(toSave.toXContent(builder, RestHandlerUtils.XCONTENT_WITH_TYPE));
 if (fixedDoc) {
@@ -192,14 +191,14 @@ protected void save(T toSave, String detectorId, String indexName) {
 }
 }
- void saveIteration(IndexRequest indexRequest, String detectorId, Iterator<TimeValue> backoff) {
+ void saveIteration(IndexRequest indexRequest, String configId, Iterator<TimeValue> backoff) {
 clientUtil
 .asyncRequest(
 indexRequest,
 client::index,
 ActionListener
 .wrap(
- response -> { LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, detectorId)); },
+ response -> { LOG.debug(String.format(Locale.ROOT, SUCCESS_SAVING_MSG, configId)); },
 exception -> {
 // OpenSearch has a thread pool and a queue for write per node. A thread
 // pool will have N number of workers ready to handle the requests. When a
@@ -215,13 +214,13 @@ void saveIteration(IndexRequest indexRequest, String detectorId, Iterator
- () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), detectorId, backoff),
+ () -> saveIteration(BulkUtil.cloneIndexRequest(indexRequest), configId, backoff),
 nextDelay,
 ThreadPool.Names.SAME
 );
diff --git a/src/main/java/org/opensearch/ad/util/BulkUtil.java b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java
similarity index 98%
rename from src/main/java/org/opensearch/ad/util/BulkUtil.java
rename to src/main/java/org/opensearch/timeseries/util/BulkUtil.java
index d7fe9c6f6..c2b275a1f 100644
--- a/src/main/java/org/opensearch/ad/util/BulkUtil.java
+++ b/src/main/java/org/opensearch/timeseries/util/BulkUtil.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.util;
+package org.opensearch.timeseries.util;
 import java.util.ArrayList;
 import java.util.HashSet;
diff --git a/src/main/java/org/opensearch/timeseries/util/ClientUtil.java b/src/main/java/org/opensearch/timeseries/util/ClientUtil.java
new file mode 100644
index 000000000..011819a07
--- /dev/null
+++ b/src/main/java/org/opensearch/timeseries/util/ClientUtil.java
@@ -0,0 +1,71 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ *
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+package org.opensearch.timeseries.util;
+
+import java.util.function.BiConsumer;
+
+import org.opensearch.action.ActionListener;
+import org.opensearch.action.ActionRequest;
+import org.opensearch.action.ActionResponse;
+import org.opensearch.action.ActionType;
+import org.opensearch.client.Client;
+import org.opensearch.common.inject.Inject;
+
+public class ClientUtil {
+ private Client client;
+
+ @Inject
+ public ClientUtil(Client client) {
+ this.client = client;
+ }
+
+ /**
+ * Send an asynchronous request and handle response with the provided listener.
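+ * For example (illustrative): {@code clientUtil.asyncRequest(getRequest, client::get, ActionListener.wrap(onGet, onError))}.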
+ * @param <Request> ActionRequest type
+ * @param <Response> ActionResponse type
+ * @param request request body
+ * @param consumer request method, functional interface to operate as a client request like client::get
+ * @param listener needed to handle response
+ */
+ public <Request extends ActionRequest, Response extends ActionResponse> void asyncRequest(
+ Request request,
+ BiConsumer<Request, ActionListener<Response>> consumer,
+ ActionListener<Response> listener
+ ) {
+ consumer
+ .accept(
+ request,
+ ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); })
+ );
+ }
+
+ /**
+ * Execute a transport action and handle response with the provided listener.
+ * @param <Request> ActionRequest type
+ * @param <Response> ActionResponse type
+ * @param action transport action
+ * @param request request body
+ * @param listener needed to handle response
+ */
+ public <Request extends ActionRequest, Response extends ActionResponse> void execute(
+ ActionType<Response> action,
+ Request request,
+ ActionListener<Response> listener
+ ) {
+ client
+ .execute(
+ action,
+ request,
+ ActionListener.wrap(response -> { listener.onResponse(response); }, exception -> { listener.onFailure(exception); })
+ );
+ }
+}
diff --git a/src/main/java/org/opensearch/ad/util/DateUtils.java b/src/main/java/org/opensearch/timeseries/util/DateUtils.java
similarity index 96%
rename from src/main/java/org/opensearch/ad/util/DateUtils.java
rename to src/main/java/org/opensearch/timeseries/util/DateUtils.java
index e7cfc21ce..a76fc5bcb 100644
--- a/src/main/java/org/opensearch/ad/util/DateUtils.java
+++ b/src/main/java/org/opensearch/timeseries/util/DateUtils.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.util;
+package org.opensearch.timeseries.util;
 import java.time.Duration;
 import java.time.Instant;
diff --git a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java b/src/main/java/org/opensearch/timeseries/util/ExceptionUtil.java
similarity index 99%
rename from src/main/java/org/opensearch/ad/util/ExceptionUtil.java
rename to src/main/java/org/opensearch/timeseries/util/ExceptionUtil.java
index b48cf49e4..d71204c82 100644
--- a/src/main/java/org/opensearch/ad/util/ExceptionUtil.java
+++ b/src/main/java/org/opensearch/timeseries/util/ExceptionUtil.java
@@ -9,7 +9,7 @@
 * GitHub history for details.
 */
-package org.opensearch.ad.util;
+package org.opensearch.timeseries.util;
 import java.util.EnumSet;
 import java.util.concurrent.RejectedExecutionException;
diff --git a/src/main/java/org/opensearch/ad/util/IndexUtils.java b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java
similarity index 78%
rename from src/main/java/org/opensearch/ad/util/IndexUtils.java
rename to src/main/java/org/opensearch/timeseries/util/IndexUtils.java
index b69c0924a..de11b66fa 100644
--- a/src/main/java/org/opensearch/ad/util/IndexUtils.java
+++ b/src/main/java/org/opensearch/timeseries/util/IndexUtils.java
@@ -9,18 +9,14 @@
 * GitHub history for details.
*/ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.List; import java.util.Locale; -import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.action.admin.indices.stats.IndicesStatsRequest; -import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; import org.opensearch.action.support.IndicesOptions; -import org.opensearch.client.Client; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.block.ClusterBlockLevel; import org.opensearch.cluster.health.ClusterIndexHealth; @@ -43,8 +39,6 @@ public class IndexUtils { private static final Logger logger = LogManager.getLogger(IndexUtils.class); - private Client client; - private ClientUtil clientUtil; private ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; @@ -57,14 +51,7 @@ public class IndexUtils { * @param indexNameExpressionResolver index name resolver */ @Inject - public IndexUtils( - Client client, - ClientUtil clientUtil, - ClusterService clusterService, - IndexNameExpressionResolver indexNameExpressionResolver - ) { - this.client = client; - this.clientUtil = clientUtil; + public IndexUtils(ClusterService clusterService, IndexNameExpressionResolver indexNameExpressionResolver) { this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; } @@ -110,25 +97,6 @@ public String getIndexHealthStatus(String indexOrAliasName) throws IllegalArgume return indexHealth.getStatus().name().toLowerCase(Locale.ROOT); } - /** - * Gets the number of documents in an index. - * - * @deprecated - * - * @param indexName Name of the index - * @return The number of documents in an index. 0 is returned if the index does not exist. -1 is returned if the - * request fails. - */ - @Deprecated - public Long getNumberOfDocumentsInIndex(String indexName) { - if (!clusterService.state().getRoutingTable().hasIndex(indexName)) { - return 0L; - } - IndicesStatsRequest indicesStatsRequest = new IndicesStatsRequest(); - Optional response = clientUtil.timedRequest(indicesStatsRequest, logger, client.admin().indices()::stats); - return response.map(r -> r.getIndex(indexName).getPrimaries().docs.getCount()).orElse(-1L); - } - /** * Similar to checkGlobalBlock, we check block on the indices level. * diff --git a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java similarity index 98% rename from src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java rename to src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java index 8b18bf9c3..fe9ea1cc2 100644 --- a/src/main/java/org/opensearch/ad/util/MultiResponsesDelegateActionListener.java +++ b/src/main/java/org/opensearch/timeseries/util/MultiResponsesDelegateActionListener.java @@ -9,7 +9,7 @@ * GitHub history for details. 
 */
-package org.opensearch.ad.util;
+package org.opensearch.timeseries.util;
 import java.util.ArrayList;
 import java.util.Collections;
diff --git a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java
index ee73be777..76d9beddc 100644
--- a/src/main/java/org/opensearch/timeseries/util/ParseUtils.java
+++ b/src/main/java/org/opensearch/timeseries/util/ParseUtils.java
@@ -17,7 +17,6 @@
 import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.opensearch.search.aggregations.AggregationBuilders.dateRange;
 import static org.opensearch.search.aggregations.AggregatorFactories.VALID_AGG_NAME;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
 import static org.opensearch.timeseries.settings.TimeSeriesSettings.MAX_BATCH_TASK_PIECE_SIZE;
 import java.io.IOException;
@@ -38,11 +37,10 @@
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.search.join.ScoreMode;
 import org.opensearch.action.ActionListener;
+import org.opensearch.action.ActionResponse;
 import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.action.search.SearchResponse;
-import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
@@ -73,7 +71,9 @@
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.model.Config;
 import org.opensearch.timeseries.model.Entity;
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.FeatureData;
@@ -303,23 +303,23 @@ public static AggregatorFactories.Builder parseAggregators(XContentParser parser
 }
 public static SearchSourceBuilder generateInternalFeatureQuery(
- AnomalyDetector detector,
+ Config config,
 long startTime,
 long endTime,
 NamedXContentRegistry xContentRegistry
 ) throws IOException {
- RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField())
+ RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField())
 .from(startTime)
 .to(endTime)
 .format("epoch_millis")
 .includeLower(true)
 .includeUpper(false);
- BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery());
+ BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery());
 SearchSourceBuilder internalSearchSourceBuilder = new SearchSourceBuilder().query(internalFilterQuery);
- if (detector.getFeatureAttributes() != null) {
- for (Feature feature : detector.getFeatureAttributes()) {
+ if (config.getFeatureAttributes() != null) {
+ for (Feature feature : config.getFeatureAttributes()) {
 AggregatorFactories.Builder internalAgg = parseAggregators(
 feature.getAggregation().toString(),
 xContentRegistry,
@@ -333,18 +333,18 @@ public static SearchSourceBuilder generateInternalFeatureQuery(
 }
 public static SearchSourceBuilder generatePreviewQuery(
- AnomalyDetector detector,
+ Config config,
 List<Entry<Long, Long>> ranges,
 NamedXContentRegistry xContentRegistry
 ) throws IOException {
-
DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis");
+ DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(config.getTimeField()).format("epoch_millis");
 for (Entry<Long, Long> range : ranges) {
 dateRangeBuilder.addRange(range.getKey(), range.getValue());
 }
- if (detector.getFeatureAttributes() != null) {
- for (Feature feature : detector.getFeatureAttributes()) {
+ if (config.getFeatureAttributes() != null) {
+ for (Feature feature : config.getFeatureAttributes()) {
 AggregatorFactories.Builder internalAgg = parseAggregators(
 feature.getAggregation().toString(),
 xContentRegistry,
@@ -354,29 +354,31 @@ public static SearchSourceBuilder generatePreviewQuery(
 }
 }
- return new SearchSourceBuilder().query(detector.getFilterQuery()).size(0).aggregation(dateRangeBuilder);
+ return new SearchSourceBuilder().query(config.getFilterQuery()).size(0).aggregation(dateRangeBuilder);
 }
- public static SearchSourceBuilder generateEntityColdStartQuery(
- AnomalyDetector detector,
+ public static SearchSourceBuilder generateColdStartQuery(
+ Config config,
 List<Entry<Long, Long>> ranges,
- Entity entity,
+ Optional<Entity> entity,
 NamedXContentRegistry xContentRegistry
 ) throws IOException {
- BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(detector.getFilterQuery());
+ BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().filter(config.getFilterQuery());
- for (TermQueryBuilder term : entity.getTermQueryBuilders()) {
- internalFilterQuery.filter(term);
+ if (entity.isPresent()) {
+ for (TermQueryBuilder term : entity.get().getTermQueryBuilders()) {
+ internalFilterQuery.filter(term);
+ }
 }
- DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(detector.getTimeField()).format("epoch_millis");
+ DateRangeAggregationBuilder dateRangeBuilder = dateRange("date_range").field(config.getTimeField()).format("epoch_millis");
 for (Entry<Long, Long> range : ranges) {
 dateRangeBuilder.addRange(range.getKey(), range.getValue());
 }
- if (detector.getFeatureAttributes() != null) {
- for (Feature feature : detector.getFeatureAttributes()) {
+ if (config.getFeatureAttributes() != null) {
+ for (Feature feature : config.getFeatureAttributes()) {
 AggregatorFactories.Builder internalAgg = parseAggregators(
 feature.getAggregation().toString(),
 xContentRegistry,
@@ -392,12 +394,12 @@ public static SearchSourceBuilder generateColdStartQuery(
 /**
 * Map feature data to its Id and name
 * @param currentFeature Feature data
- * @param detector Detector Config object
+ * @param config Config object
 * @return a list of feature data with Id and name
 */
- public static List<FeatureData> getFeatureData(double[] currentFeature, AnomalyDetector detector) {
- List<String> featureIds = detector.getEnabledFeatureIds();
- List<String> featureNames = detector.getEnabledFeatureNames();
+ public static List<FeatureData> getFeatureData(double[] currentFeature, Config config) {
+ List<String> featureIds = config.getEnabledFeatureIds();
+ List<String> featureNames = config.getEnabledFeatureNames();
 int featureLen = featureIds.size();
 List<FeatureData> featureData = new ArrayList<>();
 for (int i = 0; i < featureLen; i++) {
@@ -443,23 +445,51 @@
 public static User getUserContext(Client client) {
 return User.parse(userStr);
 }
- public static void resolveUserAndExecute(
+ /**
+ * Run the given function based on the given user.
+ * @param <T> Config response type.
- public static void resolveUserAndExecute( + /** + * Run the given function based on the given user + * @param <GetConfigResponseType> Config response type. Can be either GetAnomalyDetectorResponse or GetForecasterResponse + * @param requestedUser requested user + * @param configId config id + * @param filterByEnabled whether filtering by backend role is enabled + * @param listener listener. We didn't provide the generic type of the listener, so callers can return anything through it. + * @param function Function to execute + * @param client Client to OS. + * @param clusterService Cluster service of OS. + * @param xContentRegistry Used to deserialize the get config response. + * @param getConfigResponseClass The class of the Get config transport response. Based on the class type, we run a different parse method + * on the get response. + * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config + */ + public static <ConfigType extends Config, GetConfigResponseType extends ActionResponse> void resolveUserAndExecute( User requestedUser, - String detectorId, + String configId, boolean filterByEnabled, ActionListener listener, - Consumer<AnomalyDetector> function, + Consumer<ConfigType> function, Client client, ClusterService clusterService, - NamedXContentRegistry xContentRegistry + NamedXContentRegistry xContentRegistry, + Class<GetConfigResponseType> getConfigResponseClass, + Class<ConfigType> configTypeClass ) { try { - if (requestedUser == null || detectorId == null) { + if (requestedUser == null || configId == null) { // requestedUser == null means security is disabled or user is superadmin. In this case we don't need to // check if the request user has access to the detector or not. function.accept(null); } else { - getDetector(requestedUser, detectorId, listener, function, client, clusterService, xContentRegistry, filterByEnabled); + getConfig( + requestedUser, + configId, + listener, + function, + client, + clusterService, + xContentRegistry, + filterByEnabled, + getConfigResponseClass, + configTypeClass + ); } } catch (Exception e) { listener.onFailure(e); }
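Reviewer note: a sketch of how an AD-side caller is expected to use the generalized helper. The enclosing transport action and its listener are illustrative; GetAnomalyDetectorResponse and AnomalyDetector are the concrete AD types named in the javadoc above:

    ParseUtils.resolveUserAndExecute(
        user,
        detectorId,
        filterByEnabled,
        listener,                                        // e.g. ActionListener<GetAnomalyDetectorResponse>
        (AnomalyDetector detector) -> handle(detector),  // runs after access checks pass; receives null when security is off
        client,
        clusterService,
        xContentRegistry,
        GetAnomalyDetectorResponse.class,                // selects the AD parse path for the get response
        AnomalyDetector.class                            // Config subtype to deserialize from the config index
    );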
@@ -470,42 +500,49 @@ public static void resolveUserAndExecute( * If filterByEnabled is true, get detector and check if the user has permissions to access the detector, * then execute function; otherwise, get detector and execute function * @param requestUser user from request - * @param detectorId detector id + * @param configId config id * @param listener action listener * @param function consumer function * @param client client * @param clusterService cluster service * @param xContentRegistry XContent registry * @param filterByBackendRole filter by backend role or not + * @param getConfigResponseClass The class of the Get config transport response. Based on the class type, we run a different parse method + * on the get response. + * @param configTypeClass the class of the ConfigType, used by the ConfigFactory to parse the correct type of Config */ - public static void getDetector( + public static <ConfigType extends Config, GetConfigResponseType extends ActionResponse> void getConfig( User requestUser, - String detectorId, - ActionListener<GetAnomalyDetectorResponse> listener, - Consumer<AnomalyDetector> function, + String configId, + ActionListener listener, + Consumer<ConfigType> function, Client client, ClusterService clusterService, NamedXContentRegistry xContentRegistry, - boolean filterByBackendRole + boolean filterByBackendRole, + Class<GetConfigResponseType> getConfigResponseClass, + Class<ConfigType> configTypeClass ) { if (clusterService.state().metadata().indices().containsKey(CommonName.CONFIG_INDEX)) { - GetRequest request = new GetRequest(CommonName.CONFIG_INDEX).id(detectorId); + GetRequest request = new GetRequest(CommonName.CONFIG_INDEX).id(configId); client .get( request, ActionListener .wrap( - response -> onGetAdResponse( + response -> onGetConfigResponse( response, requestUser, - detectorId, + configId, listener, function, xContentRegistry, - filterByBackendRole + filterByBackendRole, + getConfigResponseClass, + configTypeClass ), exception -> { - logger.error("Failed to get anomaly detector: " + detectorId, exception); + logger.error("Failed to get config: " + configId, exception); listener.onFailure(exception); } ) @@ -515,34 +552,55 @@ public static void getDetector( } }
- public static void onGetAdResponse( + /** + * This method processes the GetResponse and applies the consumer function if the user has permissions + * or if the filterByBackendRole is disabled. It uses a ConfigFactory to parse the correct type of Config. + * + * @param <ConfigType> The type of Config to be processed in this method, which extends from the Config base type. + * @param <GetConfigResponseType> The type of ActionResponse to be used, which extends from the ActionResponse base type. + * @param response The GetResponse from the getConfig request. This contains the information about the config that is to be processed. + * @param requestUser The User from the request. This user's permissions will be checked to ensure they have access to the config. + * @param configId The ID of the config. This is used for logging and error messages. + * @param listener The ActionListener to call if an error occurs. Any errors that occur during the processing of the config will be passed to this listener. + * @param function The Consumer function to apply to the ConfigType. If the user has permission to access the config, this function will be applied. + * @param xContentRegistry The XContentRegistry used to create the XContentParser. This is used to parse the response into a ConfigType. + * @param filterByBackendRole A boolean indicating whether to filter by backend role. If true, the user's backend roles will be checked to ensure they have access to the config. + * @param getConfigResponseClass The class of the Get config transport response. Based on the class type, we run a different parse method + * on the get response. + * @param configTypeClass The class of the ConfigType, used by the ConfigFactory to parse the correct type of Config. + */ + public static <ConfigType extends Config, GetConfigResponseType extends ActionResponse> void onGetConfigResponse( GetResponse response, User requestUser, - String detectorId, - ActionListener<GetAnomalyDetectorResponse> listener, - Consumer<AnomalyDetector> function, + String configId, + ActionListener listener, + Consumer<ConfigType> function, NamedXContentRegistry xContentRegistry, - boolean filterByBackendRole + boolean filterByBackendRole, + Class<GetConfigResponseType> getConfigResponseClass, + Class<ConfigType> configTypeClass ) { if (response.isExists()) { try ( XContentParser parser = RestHandlerUtils.createXContentParserFromRegistry(xContentRegistry, response.getSourceAsBytesRef()) ) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - AnomalyDetector detector = AnomalyDetector.parse(parser); + @SuppressWarnings("unchecked") + ConfigType detector = (ConfigType) Config.parseConfig(configTypeClass, parser); + User resourceUser = detector.getUser(); - if (!filterByBackendRole || checkUserPermissions(requestUser, resourceUser, detectorId) || isAdmin(requestUser)) { + if (!filterByBackendRole || checkUserPermissions(requestUser, resourceUser, configId) || isAdmin(requestUser)) { function.accept(detector); } else { - logger.debug("User: " + requestUser.getName() + " does not have permissions to access detector: " + detectorId); - listener.onFailure(new TimeSeriesException(NO_PERMISSION_TO_ACCESS_DETECTOR + detectorId)); + logger.debug("User: " + requestUser.getName() + " does not have permissions to access detector: " + configId); + listener.onFailure(new TimeSeriesException(NO_PERMISSION_TO_ACCESS_DETECTOR + configId)); } } catch (Exception e) { - listener.onFailure(new TimeSeriesException(FAIL_TO_GET_USER_INFO + detectorId)); + listener.onFailure(new TimeSeriesException(FAIL_TO_GET_USER_INFO + configId)); } } else { - listener.onFailure(new ResourceNotFoundException(detectorId, FAIL_TO_FIND_CONFIG_MSG + detectorId)); + listener.onFailure(new ResourceNotFoundException(configId, CommonMessages.FAIL_TO_FIND_CONFIG_MSG + configId)); } } @@ -558,7 +616,7 @@ public static boolean isAdmin(User user) { return user.getRoles().contains("all_access"); } - private static boolean checkUserPermissions(User requestedUser, User resourceUser, String detectorId) throws Exception { + private static boolean checkUserPermissions(User requestedUser, User resourceUser, String configId) throws Exception { if (resourceUser.getBackendRoles() == null || requestedUser.getBackendRoles() == null) { return false; } @@ -571,8 +629,8 @@ private static boolean checkUserPermissions(User requestedUser, User resourceUse + requestedUser.getName() + " has backend role: " + backendRole - + " permissions to access detector: " - + detectorId + + " permissions to access config: " + + configId ); return true; }
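Reviewer note: the branches above collapse to a single access rule; an equivalent predicate, for review purposes only:

    boolean allowed = !filterByBackendRole                      // role filtering disabled: everyone is allowed through
        || isAdmin(requestUser)                                 // "all_access" users bypass the check
        || requestUser.getBackendRoles().stream()               // otherwise require at least one shared backend role
            .anyMatch(resourceUser.getBackendRoles()::contains);
    // (checkUserPermissions additionally denies access outright when either role list is null)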
@@ -613,7 +671,7 @@ public static Optional<Long> getLatestDataTime(SearchResponse searchResponse) { /** * Generate batch query request for feature aggregation on given date range. * - * @param detector anomaly detector + * @param config config accessor * @param entity entity * @param startTime start time * @param endTime end time @@ -623,46 +681,46 @@ public static Optional<Long> getLatestDataTime(SearchResponse searchResponse) { * @throws TimeSeriesException throw AD exception if no enabled feature */ public static SearchSourceBuilder batchFeatureQuery( - AnomalyDetector detector, + Config config, Entity entity, long startTime, long endTime, NamedXContentRegistry xContentRegistry ) throws IOException { - RangeQueryBuilder rangeQuery = new RangeQueryBuilder(detector.getTimeField()) + RangeQueryBuilder rangeQuery = new RangeQueryBuilder(config.getTimeField()) .from(startTime) .to(endTime) .format(EPOCH_MILLIS_FORMAT) .includeLower(true) .includeUpper(false); - BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(detector.getFilterQuery()); + BoolQueryBuilder internalFilterQuery = QueryBuilders.boolQuery().must(rangeQuery).must(config.getFilterQuery()); - if (detector.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { + if (config.isHighCardinality() && entity != null && entity.getAttributes().size() > 0) { entity .getAttributes() .entrySet() .forEach(attr -> { internalFilterQuery.filter(new TermQueryBuilder(attr.getKey(), attr.getValue())); }); } - long intervalSeconds = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().getSeconds(); + long intervalSeconds = ((IntervalTimeConfiguration) config.getInterval()).toDuration().getSeconds(); List<CompositeValuesSourceBuilder<?>> sources = new ArrayList<>(); sources .add( new DateHistogramValuesSourceBuilder(CommonName.DATE_HISTOGRAM) - .field(detector.getTimeField()) + .field(config.getTimeField()) .fixedInterval(DateHistogramInterval.seconds((int) intervalSeconds)) ); CompositeAggregationBuilder aggregationBuilder = new CompositeAggregationBuilder(CommonName.FEATURE_AGGS, sources) .size(MAX_BATCH_TASK_PIECE_SIZE); - if (detector.getEnabledFeatureIds().size() == 0) { + if (config.getEnabledFeatureIds().size() == 0) { throw new TimeSeriesException("No enabled feature configured").countedInStats(false); } - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { if (feature.getEnabled()) { AggregatorFactories.Builder internalAgg = parseAggregators( feature.getAggregation().toString(), @@ -738,9 +796,9 @@ public static List<String> parseAggregationRequest(XContentParser parser) throws return fieldNames; } - public static List<String> getFeatureFieldNames(AnomalyDetector detector, NamedXContentRegistry xContentRegistry) throws IOException { + public static List<String> getFeatureFieldNames(Config config, NamedXContentRegistry xContentRegistry) throws IOException { List<String> featureFields = new ArrayList<>(); - for (Feature feature : detector.getFeatureAttributes()) { + for (Feature feature : config.getFeatureAttributes()) { featureFields.add(getFieldNamesForFeature(feature, xContentRegistry).get(0)); } return featureFields;
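Reviewer note: a sketch of driving one historical batch piece with the generalized builder; pieceStart/pieceEnd, the entity, and consumeBuckets are stand-ins for the batch task runner's state:

    SearchSourceBuilder source = ParseUtils.batchFeatureQuery(config, entity, pieceStart, pieceEnd, registry);
    SearchRequest request = new SearchRequest(config.getIndices().toArray(new String[0])).source(source);
    // Results arrive as composite date-histogram buckets (CommonName.FEATURE_AGGS), at most
    // MAX_BATCH_TASK_PIECE_SIZE per piece; a null entity is fine for single-stream configs.
    client.search(request, ActionListener.wrap(r -> consumeBuckets(r), listener::onFailure));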
diff --git a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java index 73ef78aef..f84c22cbb 100644 --- a/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/timeseries/util/RestHandlerUtils.java @@ -17,6 +17,7 @@ import java.io.IOException; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import org.apache.commons.lang.ArrayUtils; @@ -45,7 +46,9 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import com.google.common.base.Throwables; @@ -63,11 +66,8 @@ public final class RestHandlerUtils { public static final String _PRIMARY_TERM = "_primary_term"; public static final String IF_PRIMARY_TERM = "if_primary_term"; public static final String REFRESH = "refresh"; - public static final String DETECTOR_ID = "detectorID"; public static final String RESULT_INDEX = "resultIndex"; - public static final String ANOMALY_DETECTOR = "anomaly_detector"; - public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; - public static final String REALTIME_TASK = "realtime_detection_task"; + public static final String REALTIME_TASK = "realtime_task"; public static final String HISTORICAL_ANALYSIS_TASK = "historical_analysis_task"; public static final String RUN = "_run"; public static final String PREVIEW = "_preview"; @@ -82,13 +82,20 @@ public final class RestHandlerUtils { public static final String TOP_ANOMALIES = "_topAnomalies"; public static final String VALIDATE = "_validate"; public static final ToXContent.MapParams XCONTENT_WITH_TYPE = new ToXContent.MapParams(ImmutableMap.of("with_type", "true")); + public static final String REST_STATUS = "rest_status"; public static final String OPENSEARCH_DASHBOARDS_USER_AGENT = "OpenSearch Dashboards"; public static final String[] UI_METADATA_EXCLUDE = new String[] { Config.UI_METADATA_FIELD }; + // AD constants + public static final String DETECTOR_ID = "detectorID"; + public static final String ANOMALY_DETECTOR = "anomaly_detector"; + public static final String ANOMALY_DETECTOR_JOB = "anomaly_detector_job"; + + // forecast constants public static final String FORECASTER_ID = "forecasterID"; public static final String FORECASTER = "forecaster"; - public static final String REST_STATUS = "rest_status"; + public static final String FORECASTER_JOB = "forecaster_job"; private RestHandlerUtils() {} @@ -247,4 +254,32 @@ public static boolean isProperExceptionToReturn(Throwable e) { private static String coalesceToEmpty(@Nullable String s) { return s == null ? "" : s; }
"" : s; } + + public static Entity buildEntity(RestRequest request, String detectorId) throws IOException { + if (org.opensearch.core.common.Strings.isEmpty(detectorId)) { + throw new IllegalStateException(CommonMessages.CONFIG_ID_MISSING_MSG); + } + + String entityName = request.param(CommonName.CATEGORICAL_FIELD); + String entityValue = request.param(CommonName.ENTITY_KEY); + + if (entityName != null && entityValue != null) { + // single-stream profile request: + // GET + // _plugins/_anomaly_detection/detectors//_profile/init_progress?category_field=&entity= + return Entity.createSingleAttributeEntity(entityName, entityValue); + } else if (request.hasContent()) { + /* + * HCAD profile request: GET + * _plugins/_anomaly_detection/detectors//_profile/init_progress { + * "entity": [{ "name": "clientip", "value": "13.24.0.0" }] } + */ + Optional entity = Entity.fromJsonObject(request.contentParser()); + if (entity.isPresent()) { + return entity.get(); + } + } + // not a valid profile request with correct entity information + return null; + } } diff --git a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java b/src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java similarity index 98% rename from src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java rename to src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java index 612ea4d5c..671aa0466 100644 --- a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java +++ b/src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.List; import java.util.Locale; diff --git a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java b/src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java similarity index 82% rename from src/main/java/org/opensearch/ad/util/SecurityClientUtil.java rename to src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java index 8e9b97b57..82d72e1ab 100644 --- a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java @@ -9,7 +9,7 @@ * GitHub history for details. 
diff --git a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java b/src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java similarity index 98% rename from src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java rename to src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java index 612ea4d5c..671aa0466 100644 --- a/src/main/java/org/opensearch/ad/util/SafeSecurityInjector.java +++ b/src/main/java/org/opensearch/timeseries/util/SafeSecurityInjector.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.List; import java.util.Locale;
diff --git a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java b/src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java similarity index 82% rename from src/main/java/org/opensearch/ad/util/SecurityClientUtil.java rename to src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java index 8e9b97b57..82d72e1ab 100644 --- a/src/main/java/org/opensearch/ad/util/SecurityClientUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/SecurityClientUtil.java @@ -9,7 +9,7 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.function.BiConsumer; @@ -17,12 +17,13 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; -import org.opensearch.ad.NodeStateManager; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; public class SecurityClientUtil { private static final String INJECTION_ID = "direct"; @@ -51,12 +52,21 @@ public void asy BiConsumer<Request, ActionListener<Response>> consumer, String detectorId, Client client, + AnalysisType context, ActionListener<Response> listener ) { ThreadContext threadContext = client.threadPool().getThreadContext(); - try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(detectorId, settings, threadContext, nodeStateManager)) { + try ( + TimeSeriesSafeSecurityInjector injectSecurity = new TimeSeriesSafeSecurityInjector( + detectorId, + settings, + threadContext, + nodeStateManager, + context + ) + ) { injectSecurity - .injectUserRolesFromDetector( + .injectUserRolesFromConfig( ActionListener .wrap( success -> consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())), @@ -82,6 +92,7 @@ public void asy BiConsumer<Request, ActionListener<Response>> consumer, User user, Client client, + AnalysisType context, ActionListener<Response> listener ) { ThreadContext threadContext = client.threadPool().getThreadContext(); @@ -95,7 +106,15 @@ public void asy // client.execute/client.search and handles the responses (this can be a thread in the search thread pool). // Auto-close in try will restore the context in one thread; the explicit close injectSecurity will restore // the context in another thread. So we still need to put the injectSecurity inside try.
- try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + try ( + TimeSeriesSafeSecurityInjector injectSecurity = new TimeSeriesSafeSecurityInjector( + INJECTION_ID, + settings, + threadContext, + nodeStateManager, + context + ) + ) { injectSecurity.injectUserRoles(user); consumer.accept(request, ActionListener.runBefore(listener, () -> injectSecurity.close())); } @@ -117,12 +136,21 @@ public void exe Request request, User user, Client client, + AnalysisType context, ActionListener<Response> listener ) { ThreadContext threadContext = client.threadPool().getThreadContext(); // use a hardcoded string as detector id that is only used in logging - try (ADSafeSecurityInjector injectSecurity = new ADSafeSecurityInjector(INJECTION_ID, settings, threadContext, nodeStateManager)) { + try ( + TimeSeriesSafeSecurityInjector injectSecurity = new TimeSeriesSafeSecurityInjector( + INJECTION_ID, + settings, + threadContext, + nodeStateManager, + context + ) + ) { injectSecurity.injectUserRoles(user); client.execute(action, request, ActionListener.runBefore(listener, () -> injectSecurity.close())); }
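Reviewer note: call sites now pass an AnalysisType so the injector reads the stored user from the right config type. A sketch against the new shape; the helper name asyncRequestWithInjectedSecurity is assumed from this class (the hunk headers above truncate it to "public void asy"), and the request and listener are stand-ins:

    clientUtil.asyncRequestWithInjectedSecurity(
        searchRequest,
        client::search,    // BiConsumer<Request, ActionListener<Response>>
        detectorId,        // config whose recorded user roles get injected
        client,
        AnalysisType.AD,   // new argument in this PR (AD here; forecasting would pass its own type)
        searchListener
    );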
diff --git a/src/main/java/org/opensearch/ad/util/SecurityUtil.java b/src/main/java/org/opensearch/timeseries/util/SecurityUtil.java similarity index 86% rename from src/main/java/org/opensearch/ad/util/SecurityUtil.java rename to src/main/java/org/opensearch/timeseries/util/SecurityUtil.java index d72d345ab..135116c34 100644 --- a/src/main/java/org/opensearch/ad/util/SecurityUtil.java +++ b/src/main/java/org/opensearch/timeseries/util/SecurityUtil.java @@ -9,15 +9,15 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.Collections; import java.util.List; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.common.settings.Settings; import org.opensearch.commons.authuser.User; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; import com.google.common.collect.ImmutableList; @@ -57,12 +57,12 @@ private static User getAdjustedUserBWC(User userObj, Settings settings) { /** * * - * @param detector Detector config + * @param config analysis config * @param settings Node settings * @return user recorded by a config. Adjusted for BWC (backward compatibility) if necessary. */ - public static User getUserFromDetector(AnomalyDetector detector, Settings settings) { - return getAdjustedUserBWC(detector.getUser(), settings); + public static User getUserFromConfig(Config config, Settings settings) { + return getAdjustedUserBWC(config.getUser(), settings); } /** @@ -71,7 +71,7 @@ public static User getUserFromDetector(AnomalyDetector detector, Settings settin * @param settings Node settings * @return user recorded by a detector job */ - public static User getUserFromJob(AnomalyDetectorJob detectorJob, Settings settings) { + public static User getUserFromJob(Job detectorJob, Settings settings) { return getAdjustedUserBWC(detectorJob.getUser(), settings); } }
diff --git a/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java b/src/main/java/org/opensearch/timeseries/util/TimeSeriesSafeSecurityInjector.java similarity index 55% rename from src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java rename to src/main/java/org/opensearch/timeseries/util/TimeSeriesSafeSecurityInjector.java index 749a7434c..fcb07c39a 100644 --- a/src/main/java/org/opensearch/ad/util/ADSafeSecurityInjector.java +++ b/src/main/java/org/opensearch/timeseries/util/TimeSeriesSafeSecurityInjector.java @@ -9,31 +9,40 @@ * GitHub history for details. */ -package org.opensearch.ad.util; +package org.opensearch.timeseries.util; import java.util.Optional; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionListener; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.Strings; +import org.opensearch.timeseries.AnalysisType; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.common.exception.EndRunException; +import org.opensearch.timeseries.model.Config; -public class ADSafeSecurityInjector extends SafeSecurityInjector { - private static final Logger LOG = LogManager.getLogger(ADSafeSecurityInjector.class); +public class TimeSeriesSafeSecurityInjector extends SafeSecurityInjector { + private static final Logger LOG = LogManager.getLogger(TimeSeriesSafeSecurityInjector.class); private NodeStateManager nodeStateManager; + private AnalysisType context; - public ADSafeSecurityInjector(String detectorId, Settings settings, ThreadContext tc, NodeStateManager stateManager) { - super(detectorId, settings, tc); + public TimeSeriesSafeSecurityInjector( + String configId, + Settings settings, + ThreadContext tc, + NodeStateManager stateManager, + AnalysisType context + ) { + super(configId, settings, tc); this.nodeStateManager = stateManager; + this.context = context; } - public void injectUserRolesFromDetector(ActionListener<Void> injectListener) { + public void injectUserRolesFromConfig(ActionListener<Void> injectListener) { // if id is null, we cannot fetch a detector if (Strings.isEmpty(id)) { LOG.debug("Empty id"); @@ -48,21 +57,21 @@ public void injectUserRolesFromDetector(ActionListener<Void> injectListener) { return; } - ActionListener<Optional<AnomalyDetector>> getDetectorListener = ActionListener.wrap(detectorOp -> { + ActionListener<Optional<? extends Config>> getDetectorListener = ActionListener.wrap(detectorOp -> { if (!detectorOp.isPresent()) { - injectListener.onFailure(new EndRunException(id, "AnomalyDetector is not available.", false)); + injectListener.onFailure(new EndRunException(id, "Config is not available.", false));
return; } - AnomalyDetector detector = detectorOp.get(); - User userInfo = SecurityUtil.getUserFromDetector(detector, settings); + Config detector = detectorOp.get(); + User userInfo = SecurityUtil.getUserFromConfig(detector, settings); inject(userInfo.getName(), userInfo.getRoles()); injectListener.onResponse(null); }, injectListener::onFailure); - // Since we are gonna read user from detector, make sure the anomaly detector exists and fetched from disk or cached memory - // We don't accept a passed-in AnomalyDetector because the caller might mistakenly not insert any user info in the - // constructed AnomalyDetector and thus poses risks. In the case, if the user is null, we will give admin role. - nodeStateManager.getAnomalyDetector(id, getDetectorListener); + // Since we are gonna read user from config, make sure the config exists and fetched from disk or cached memory + // We don't accept a passed-in Config because the caller might mistakenly not insert any user info in the + // constructed Config and thus poses risks. In the case, if the user is null, we will give admin role. + nodeStateManager.getConfig(id, context, getDetectorListener); } public void injectUserRoles(User user) { diff --git a/src/main/resources/mappings/anomaly-checkpoint.json b/src/main/resources/mappings/anomaly-checkpoint.json index 5e515a803..162d28c6d 100644 --- a/src/main/resources/mappings/anomaly-checkpoint.json +++ b/src/main/resources/mappings/anomaly-checkpoint.json @@ -1,7 +1,7 @@ { "dynamic": true, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "detectorId": { @@ -35,6 +35,46 @@ }, "modelV2": { "type": "text" + }, + "samples": { + "type": "nested", + "properties": { + "value_list": { + "type": "nested", + "properties": { + "feature_id": { + "type": "keyword" + }, + "data": { + "type": "double" + } + } + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } + }, + "last_processed_sample": { + "type": "nested", + "properties": { + "value_list": { + "type": "double" + }, + "data_start_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + }, + "data_end_time": { + "type": "date", + "format": "strict_date_time||epoch_millis" + } + } } } } diff --git a/src/main/resources/mappings/anomaly-detectors.json b/src/main/resources/mappings/config.json similarity index 90% rename from src/main/resources/mappings/anomaly-detectors.json rename to src/main/resources/mappings/config.json index 7db1e6d08..c64a697e7 100644 --- a/src/main/resources/mappings/anomaly-detectors.json +++ b/src/main/resources/mappings/config.json @@ -150,6 +150,23 @@ }, "detector_type": { "type": "keyword" + }, + "forecast_interval": { + "properties": { + "period": { + "properties": { + "interval": { + "type": "integer" + }, + "unit": { + "type": "keyword" + } + } + } + } + }, + "horizon": { + "type": "integer" } } } diff --git a/src/main/resources/mappings/anomaly-detector-jobs.json b/src/main/resources/mappings/job.json similarity index 96% rename from src/main/resources/mappings/anomaly-detector-jobs.json rename to src/main/resources/mappings/job.json index fb26d56d2..5783c701d 100644 --- a/src/main/resources/mappings/anomaly-detector-jobs.json +++ b/src/main/resources/mappings/job.json @@ -1,7 +1,7 @@ { "dynamic": false, "_meta": { - "schema_version": 3 + "schema_version": 4 }, "properties": { "schema_version": { @@ -100,6 +100,9 @@ } } } + }, + "type": { + "type": 
"keyword" } } } diff --git a/src/test/java/org/opensearch/StreamInputOutputTests.java b/src/test/java/org/opensearch/StreamInputOutputTests.java index a1906c43f..e7e0433e5 100644 --- a/src/test/java/org/opensearch/StreamInputOutputTests.java +++ b/src/test/java/org/opensearch/StreamInputOutputTests.java @@ -27,12 +27,10 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileRequest; import org.opensearch.ad.transport.EntityProfileResponse; -import org.opensearch.ad.transport.EntityResultRequest; import org.opensearch.ad.transport.ProfileNodeResponse; import org.opensearch.ad.transport.ProfileResponse; import org.opensearch.ad.transport.RCFResultResponse; @@ -43,6 +41,7 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; /** * Put in core package so that we can using Version's package private constructor @@ -50,7 +49,7 @@ */ public class StreamInputOutputTests extends AbstractTimeSeriesTest { // public static Version V_1_1_0 = new Version(1010099, org.apache.lucene.util.Version.LUCENE_8_8_2); - private EntityResultRequest entityResultRequest; + private EntityADResultRequest entityResultRequest; private String detectorId; private long start, end; private Map entities; @@ -98,7 +97,7 @@ private void setUpEntityResultRequest() { entities.put(entity, feature); start = 10L; end = 20L; - entityResultRequest = new EntityResultRequest(detectorId, entities, start, end); + entityResultRequest = new EntityADResultRequest(detectorId, entities, start, end); } /** @@ -110,8 +109,8 @@ public void testDeSerializeEntityResultRequest() throws IOException { entityResultRequest.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - EntityResultRequest readRequest = new EntityResultRequest(streamInput); - assertThat(readRequest.getId(), equalTo(detectorId)); + EntityADResultRequest readRequest = new EntityADResultRequest(streamInput); + assertThat(readRequest.getConfigId(), equalTo(detectorId)); assertThat(readRequest.getStart(), equalTo(start)); assertThat(readRequest.getEnd(), equalTo(end)); assertTrue(areEqualWithArrayValue(readRequest.getEntities(), entities)); diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java index e2904c319..f56374bca 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/IndexAnomalyDetectorActionHandlerTests.java @@ -44,14 +44,18 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.ADNodeStateManager; +import org.opensearch.ad.TestHelpers; import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import 
org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.IndexAnomalyDetectorResponse; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -67,6 +71,7 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.task.TaskManager; import org.opensearch.transport.TransportService; /** @@ -98,7 +103,7 @@ public class IndexAnomalyDetectorActionHandlerTests extends AbstractTimeSeriesTe private Integer maxAnomalyFeatures; private Settings settings; private RestRequest.Method method; - private ADTaskManager adTaskManager; + private TaskManager adTaskManager; private SearchFeatureDao searchFeatureDao; private Clock clock; @@ -123,7 +128,7 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); clientMock = spy(new NodeClient(settings, threadPool)); clock = mock(Clock.class); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); transportService = mock(TransportService.class); @@ -203,7 +208,7 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { // we can also use spy to overstep the final methods NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new IndexAnomalyDetectorActionHandler( @@ -240,7 +245,7 @@ public void testMoreThanTenThousandSingleEntityDetectors() throws IOException { String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); @@ -281,7 +286,7 @@ public void doE } } }; - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( @@ -315,7 +320,7 @@ public void doE verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof Exception); - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } @SuppressWarnings("unchecked") @@ -368,7 +373,7 @@ public void doE }; NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); handler = new IndexAnomalyDetectorActionHandler( @@ -463,7 +468,7 @@ public void doE }; 
NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); ClusterName clusterName = new ClusterName("test"); ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().build()).build(); @@ -503,7 +508,7 @@ public void doE if (fieldTypeName.equals(CommonName.IP_TYPE) || fieldTypeName.equals(CommonName.KEYWORD_TYPE)) { assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG)); } else { - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.CATEGORICAL_FIELD_TYPE_ERR_MSG)); + assertTrue(value.getMessage().contains(CommonMessages.CATEGORICAL_FIELD_TYPE_ERR_MSG)); } } @@ -572,7 +577,7 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { // we can also use spy to overstep the final methods NodeClient client = getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new IndexAnomalyDetectorActionHandler( @@ -606,11 +611,7 @@ public void testMoreThanTenMultiEntityDetectors() throws IOException { Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } @@ -695,7 +696,7 @@ public void testTenMultiEntityDetectorsUpdateSingleEntityAdToMulti() throws IOEx verify(channel).onFailure(response.capture()); Exception value = response.getValue(); assertTrue(value instanceof IllegalArgumentException); - assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG)); + assertTrue(value.getMessage().contains(IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG)); } @Ignore diff --git a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java index 2869943b6..e9ab2234f 100644 --- a/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java +++ b/src/test/java/org/opensearch/action/admin/indices/mapping/get/ValidateAnomalyDetectorActionHandlerTests.java @@ -32,16 +32,19 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ADNodeStateManager; +import org.opensearch.ad.TestHelpers; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import 
org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.IndexAnomalyDetectorActionHandler; import org.opensearch.ad.rest.handler.ValidateAnomalyDetectorActionHandler; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.ValidateAnomalyDetectorResponse; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.cluster.service.ClusterService; @@ -53,7 +56,10 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.feature.SearchFeatureDao; import org.opensearch.timeseries.model.ValidationAspect; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; @@ -76,7 +82,7 @@ public class ValidateAnomalyDetectorActionHandlerTests extends AbstractTimeSerie protected Integer maxAnomalyFeatures; protected Settings settings; protected RestRequest.Method method; - protected ADTaskManager adTaskManager; + protected TaskManager adTaskManager; protected SearchFeatureDao searchFeatureDao; protected Clock clock; @@ -143,7 +149,7 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc .getCustomNodeClient(detectorResponse, userIndexResponse, singleEntityDetector, threadPool); NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new ValidateAnomalyDetectorActionHandler( @@ -174,7 +180,7 @@ public void testValidateMoreThanThousandSingleEntityDetectorLimit() throws IOExc String errorMsg = String .format( Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_ENTITY_DETECTORS_PREFIX_MSG, + IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_SINGLE_STREAM_DETECTORS_PREFIX_MSG, maxSingleEntityAnomalyDetectors ); assertTrue(value.getMessage().contains(errorMsg)); @@ -197,7 +203,7 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio NodeClient client = IndexAnomalyDetectorActionHandlerTests .getCustomNodeClient(detectorResponse, userIndexResponse, detector, threadPool); NodeClient clientSpy = spy(client); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); SecurityClientUtil clientUtil = new SecurityClientUtil(nodeStateManager, settings); handler = new ValidateAnomalyDetectorActionHandler( @@ -226,11 +232,7 @@ public void testValidateMoreThanTenMultiEntityDetectorsLimit() throws IOExceptio Exception value = response.getValue(); assertTrue(value instanceof ValidationException); String errorMsg = String - .format( - Locale.ROOT, - IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_MULTI_ENTITY_DETECTORS_PREFIX_MSG, - maxMultiEntityAnomalyDetectors - ); + .format(Locale.ROOT, IndexAnomalyDetectorActionHandler.EXCEEDED_MAX_HC_DETECTORS_PREFIX_MSG, maxMultiEntityAnomalyDetectors); assertTrue(value.getMessage().contains(errorMsg)); } } diff --git a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java 
b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java index f28de4547..d20953d66 100644 --- a/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AbstractProfileRunnerTests.java @@ -36,7 +36,6 @@ import org.opensearch.ad.model.DetectorProfileName; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -46,6 +45,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; public class AbstractProfileRunnerTests extends AbstractTimeSeriesTest { @@ -178,7 +178,7 @@ public void setUp() throws Exception { Consumer<Optional<ADTask>> function = (Consumer<Optional<ADTask>>) args[2]; function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); detectorIntervalMin = 3; detectorGetReponse = mock(GetResponse.class);
diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java index 074b6ee86..f1dc6ad00 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorJobRunnerTests.java @@ -22,7 +22,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.NUM_MIN_SAMPLES; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import java.io.IOException; @@ -56,15 +55,12 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultAction; import org.opensearch.ad.transport.AnomalyResultResponse; -import org.opensearch.ad.transport.handler.AnomalyIndexHandler; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -85,12 +81,16 @@ import org.opensearch.jobscheduler.spi.utils.LockService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.JobProcessor; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.FeatureData; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.settings.TimeSeriesSettings;
+import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import com.google.common.collect.ImmutableList; @@ -109,12 +109,12 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { private LockService lockService; @Mock - private AnomalyDetectorJob jobParameter; + private Job jobParameter; @Mock private JobExecutionContext context; - private AnomalyDetectorJobRunner runner = AnomalyDetectorJobRunner.getJobRunnerInstance(); + private JobProcessor runner = JobProcessor.getJobRunnerInstance(); @Mock private ThreadPool mockedThreadPool; @@ -125,7 +125,7 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { private Iterator backoff; @Mock - private AnomalyIndexHandler anomalyResultHandler; + private TimeSeriesResultIndexingHandler anomalyResultHandler; @Mock private ADTaskManager adTaskManager; @@ -141,7 +141,7 @@ public class AnomalyDetectorJobRunnerTests extends AbstractTimeSeriesTest { private ADTaskCacheManager adTaskCacheManager; @Mock - private NodeStateManager nodeStateManager; + private ADNodeStateManager nodeStateManager; private ADIndexManagement anomalyDetectionIndices; @@ -159,7 +159,7 @@ public static void tearDownAfterClass() { @Before public void setup() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(AnomalyDetectorJobRunner.class); + super.setUpLog4jForJUnit(JobProcessor.class); MockitoAnnotations.initMocks(this); ThreadFactory threadFactory = OpenSearchExecutors.daemonThreadFactory(OpenSearchExecutors.threadName("node1", "test-ad")); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -169,7 +169,7 @@ public void setup() throws Exception { Mockito.doReturn(threadContext).when(mockedThreadPool).getThreadContext(); runner.setThreadPool(mockedThreadPool); runner.setClient(client); - runner.setAdTaskManager(adTaskManager); + runner.setTaskManager(adTaskManager); Settings settings = Settings .builder() @@ -183,7 +183,7 @@ public void setup() throws Exception { anomalyDetectionIndices = mock(ADIndexManagement.class); - runner.setAnomalyDetectionIndices(anomalyDetectionIndices); + runner.setIndexManagement(anomalyDetectionIndices); lockService = new LockService(client, clusterService); doReturn(lockService).when(context).getLockService(); @@ -194,7 +194,7 @@ public void setup() throws Exception { ActionListener listener = (ActionListener) args[1]; if (request.index().equals(CommonName.JOB_INDEX)) { - AnomalyDetectorJob job = TestHelpers.randomAnomalyDetectorJob(true); + Job job = TestHelpers.randomAnomalyDetectorJob(true); listener.onResponse(TestHelpers.createGetResponse(job, randomAlphaOfLength(5), CommonName.JOB_INDEX)); } return null; @@ -230,7 +230,7 @@ public void setup() throws Exception { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); runner.setNodeStateManager(nodeStateManager); recorder = new ExecuteADResultResponseRecorder( @@ -244,7 +244,7 @@ public void setup() throws Exception { adTaskCacheManager, 32 ); - runner.setExecuteADResultResponseRecorder(recorder); + runner.setExecuteResultResponseRecorder(recorder); } @Rule @@ -289,7 +289,7 @@ public void testRunJobWithLockDuration() throws InterruptedException { @Test public void testRunAdJobWithNullLock() { LockModel lock = null; - 
runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, never()).execute(any(), any(), any()); } @@ -297,7 +297,7 @@ public void testRunAdJobWithNullLock() { public void testRunAdJobWithLock() { LockModel lock = new LockModel("indexName", "jobId", Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); } @@ -307,7 +307,7 @@ public void testRunAdJobWithExecuteException() { doThrow(RuntimeException.class).when(client).execute(any(), any(), any()); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusMillis(1000 * 60), Instant.now(), recorder, detector); verify(client, times(1)).execute(any(), any(), any()); assertTrue(testAppender.containsMessage("Failed to execute AD job")); } @@ -392,7 +392,7 @@ private void testRunAdJobWithEndRunExceptionNowAndStopAdJob(boolean jobExists, b jobExists, BytesReference .bytes( - new AnomalyDetectorJob( + new Job( jobParameter.getName(), jobParameter.getSchedule(), jobParameter.getWindowDelay(), @@ -515,10 +515,10 @@ public void testRunAdJobWithEndRunExceptionNotNowAndRetryUntilStop() throws Inte }).when(client).execute(any(), any(), any()); for (int i = 0; i < 3; i++) { - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(i + 1, testAppender.countMessage("EndRunException happened for")); } - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); assertEquals(1, testAppender.countMessage("JobRunner will stop AD job due to EndRunException retry exceeds upper limit")); } @@ -551,7 +551,7 @@ public Instant confirmInitializedSetup() { Collections.singletonList(new FeatureData("123", "abc", 0d)), randomAlphaOfLength(4), // not fully initialized - Long.valueOf(AnomalyDetectorSettings.NUM_MIN_SAMPLES - 1), + Long.valueOf(TimeSeriesSettings.NUM_MIN_SAMPLES - 1), randomLong(), // not an HC detector false, @@ -578,16 +578,16 @@ public void testFailtoFindDetector() { ActionListener> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); 
verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); - verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); - verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getConfig(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getJob(any(String.class), any(ActionListener.class)); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get detector")); @@ -601,22 +601,22 @@ public void testFailtoFindJob() { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(1); + ActionListener> listener = invocation.getArgument(1); listener.onFailure(new RuntimeException()); return null; - }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getJob(any(String.class), any(ActionListener.class)); LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); - verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); - verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getConfig(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getJob(any(String.class), any(ActionListener.class)); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get job")); @@ -630,16 +630,16 @@ public void testEmptyDetector() { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.empty()); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); - verify(nodeStateManager, 
times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); - verify(nodeStateManager, times(0)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getConfig(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(0)).getJob(any(String.class), any(ActionListener.class)); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get detector")); @@ -653,22 +653,22 @@ public void testEmptyJob() { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(1); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.empty()); return null; - }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getJob(any(String.class), any(ActionListener.class)); LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false); - runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); + runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector); verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any()); verify(adTaskCacheManager, times(1)).hasQueriedResultIndex(anyString()); - verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class)); - verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getConfig(any(String.class), any(ActionListener.class)); + verify(nodeStateManager, times(1)).getJob(any(String.class), any(ActionListener.class)); verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any()); assertEquals(1, testAppender.countMessage("Fail to confirm rcf update")); assertTrue(testAppender.containExceptionMsg(TimeSeriesException.class, "fail to get job")); @@ -688,13 +688,13 @@ public void testMarkResultIndexQueried() throws IOException { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); doAnswer(invocation -> { - ActionListener> listener = invocation.getArgument(1); + ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null))); return null; - }).when(nodeStateManager).getAnomalyDetectorJob(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getJob(any(String.class), any(ActionListener.class)); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -728,7 +728,7 @@ public void testMarkResultIndexQueried() throws IOException { ) ); 
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
-        MemoryTracker memoryTracker = mock(MemoryTracker.class);
+        ADMemoryTracker memoryTracker = mock(ADMemoryTracker.class);
         adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker);
 
         // init real time task cache for the detector. We will do this during AnomalyResultTransportAction.
@@ -752,17 +752,17 @@ public void testMarkResultIndexQueried() throws IOException {
 
         LockModel lock = new LockModel(CommonName.JOB_INDEX, jobParameter.getName(), Instant.now(), 10, false);
 
-        runner.runAdJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector);
+        runner.runJob(jobParameter, lockService, lock, Instant.now().minusSeconds(60), executionStartTime, recorder, detector);
 
         verify(client, times(1)).execute(eq(AnomalyResultAction.INSTANCE), any(), any());
         verify(client, times(1)).search(any(), any());
-        verify(nodeStateManager, times(1)).getAnomalyDetector(any(String.class), any(ActionListener.class));
-        verify(nodeStateManager, times(1)).getAnomalyDetectorJob(any(String.class), any(ActionListener.class));
+        verify(nodeStateManager, times(1)).getConfig(any(String.class), any(ActionListener.class));
+        verify(nodeStateManager, times(1)).getJob(any(String.class), any(ActionListener.class));
         ArgumentCaptor<Long> totalUpdates = ArgumentCaptor.forClass(Long.class);
         verify(adTaskManager, times(1))
             .updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), totalUpdates.capture(), any(), any(), any());
-        assertEquals(NUM_MIN_SAMPLES, totalUpdates.getValue().longValue());
+        assertEquals(TimeSeriesSettings.NUM_MIN_SAMPLES, totalUpdates.getValue().longValue());
         assertEquals(true, adTaskCacheManager.hasQueriedResultIndex(detector.getId()));
     }
 }
diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
index 5d3c54541..fb2201b88 100644
--- a/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
+++ b/src/test/java/org/opensearch/ad/AnomalyDetectorProfileRunnerTests.java
@@ -37,15 +37,12 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.action.get.GetRequest;
-import org.opensearch.action.get.GetResponse;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.DetectorInternalState;
 import org.opensearch.ad.model.DetectorProfile;
 import org.opensearch.ad.model.DetectorProfileName;
-import org.opensearch.ad.model.DetectorState;
 import org.opensearch.ad.model.InitProgressProfile;
 import org.opensearch.ad.model.ModelProfileOnNode;
 import org.opensearch.ad.transport.ProfileAction;
@@ -53,7 +50,6 @@
 import org.opensearch.ad.transport.ProfileResponse;
 import org.opensearch.ad.transport.RCFPollingAction;
 import org.opensearch.ad.transport.RCFPollingResponse;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.settings.Settings;
@@ -66,6 +62,7 @@
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.transport.RemoteTransportException;
 
 public class AnomalyDetectorProfileRunnerTests extends AbstractProfileRunnerTests {
@@ -98,12 +95,12 @@ private void setUpClientGet(
         ErrorResultStatus errorResultStatus
     ) throws IOException {
         detector = TestHelpers.randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES));
-        NodeStateManager nodeStateManager = mock(NodeStateManager.class);
+        ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class);
         doAnswer(invocation -> {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(anyString(), any(ActionListener.class));
         clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY);
         runner = new AnomalyDetectorProfileRunner(
             client,
@@ -137,7 +134,7 @@ private void setUpClientGet(
                         break;
                 }
             } else if (request.index().equals(CommonName.JOB_INDEX)) {
-                AnomalyDetectorJob job = null;
+                Job job = null;
                 switch (jobStatus) {
                     case INDEX_NOT_EXIT:
                         listener.onFailure(new IndexNotFoundException(CommonName.JOB_INDEX));
@@ -206,7 +203,7 @@ public void testDetectorNotExist() throws IOException, InterruptedException {
     public void testDisabledJobIndexTemplate(JobStatus status) throws IOException, InterruptedException {
         setUpClientGet(DetectorStatus.EXIST, status, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR);
-        DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.DISABLED).build();
+        DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.DISABLED).build();
         final CountDownLatch inProgressLatch = new CountDownLatch(1);
 
         runner.profile(detector.getId(), ActionListener.wrap(response -> {
@@ -227,7 +224,7 @@ public void testJobDisabled() throws IOException, InterruptedException {
         testDisabledJobIndexTemplate(JobStatus.DISABLED);
     }
 
-    public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorState expectedState) throws IOException,
+    public void testInitOrRunningStateTemplate(RCFPollingStatus status, ConfigState expectedState) throws IOException,
         InterruptedException {
         setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, status, ErrorResultStatus.NO_ERROR);
         DetectorProfile expectedProfile = new DetectorProfile.Builder().state(expectedState).build();
@@ -248,34 +245,34 @@ public void testInitOrRunningStateTemplate(RCFPollingStatus status, DetectorStat
     }
 
     public void testResultNotExist() throws IOException, InterruptedException {
-        testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, DetectorState.INIT);
+        testInitOrRunningStateTemplate(RCFPollingStatus.INIT_NOT_EXIT, ConfigState.INIT);
     }
 
     public void testRemoteResultNotExist() throws IOException, InterruptedException {
-        testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, DetectorState.INIT);
+        testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INIT_NOT_EXIT, ConfigState.INIT);
     }
 
     public void testCheckpointIndexNotExist() throws IOException, InterruptedException {
-        testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, DetectorState.INIT);
+        testInitOrRunningStateTemplate(RCFPollingStatus.INDEX_NOT_FOUND, ConfigState.INIT);
     }
 
     public void testRemoteCheckpointIndexNotExist() throws IOException, InterruptedException {
-        testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, DetectorState.INIT);
+        testInitOrRunningStateTemplate(RCFPollingStatus.REMOTE_INDEX_NOT_FOUND, ConfigState.INIT);
     }
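Aside (editor's sketch, not part of the patch): the hunks above and below all migrate one Mockito idiom, so it is worth seeing once in isolation. Stubbing of NodeStateManager#getAnomalyDetector becomes stubbing of ADNodeStateManager#getConfig; the listener-callback contract is unchanged. A minimal sketch, assuming ADNodeStateManager lives in org.opensearch.ad (the tests reference it unqualified) and keeps the getConfig(String, ActionListener) shape shown in the hunks:

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.Optional;

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.ADNodeStateManager; // assumed package
    import org.opensearch.ad.model.AnomalyDetector;

    class GetConfigStubSketch {
        @SuppressWarnings("unchecked")
        static ADNodeStateManager stubReturning(AnomalyDetector detector) {
            ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class);
            doAnswer(invocation -> {
                // getConfig(configId, listener): the listener is argument index 1
                ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
                listener.onResponse(Optional.of(detector)); // hand back the fixture detector
                return null;
            }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class));
            return nodeStateManager;
        }
    }

The same shape with listener.onFailure(new RuntimeException()) or Optional.empty() produces the failure and not-found branches exercised by testFailtoFindJob, testEmptyDetector, and testEmptyJob above.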
public void testResultEmpty() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, DetectorState.INIT); + testInitOrRunningStateTemplate(RCFPollingStatus.EMPTY, ConfigState.INIT); } public void testResultGreaterThanZero() throws IOException, InterruptedException { - testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, DetectorState.RUNNING); + testInitOrRunningStateTemplate(RCFPollingStatus.INIT_DONE, ConfigState.RUNNING); } @SuppressWarnings("unchecked") public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus, Set profilesToCollect @@ -289,7 +286,7 @@ public void testErrorStateTemplate( Consumer> function = (Consumer>) args[2]; function.accept(Optional.of(adTask)); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); setUpClientExecuteRCFPollingAction(initStatus); setUpClientGet(DetectorStatus.EXIST, jobStatus, initStatus, status); @@ -320,7 +317,7 @@ public void testErrorStateTemplate( public void testErrorStateTemplate( RCFPollingStatus initStatus, ErrorResultStatus status, - DetectorState state, + ConfigState state, String error, JobStatus jobStatus ) throws IOException, @@ -329,14 +326,14 @@ public void testErrorStateTemplate( } public void testRunningNoError() throws IOException, InterruptedException { - testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, DetectorState.RUNNING, null, JobStatus.ENABLED); + testErrorStateTemplate(RCFPollingStatus.INIT_DONE, ErrorResultStatus.NO_ERROR, ConfigState.RUNNING, null, JobStatus.ENABLED); } public void testRunningWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.INIT_DONE, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.RUNNING, + ConfigState.RUNNING, noFullShingleError, JobStatus.ENABLED ); @@ -346,7 +343,7 @@ public void testDisabledForStateError() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED ); @@ -356,7 +353,7 @@ public void testDisabledForStateInit() throws IOException, InterruptedException testErrorStateTemplate( RCFPollingStatus.INITTING, ErrorResultStatus.STOPPED_ERROR, - DetectorState.DISABLED, + ConfigState.DISABLED, stoppedError, JobStatus.DISABLED, stateInitProgress @@ -367,7 +364,7 @@ public void testInitWithError() throws IOException, InterruptedException { testErrorStateTemplate( RCFPollingStatus.EMPTY, ErrorResultStatus.SHINGLE_ERROR, - DetectorState.INIT, + ConfigState.INIT, noFullShingleError, JobStatus.ENABLED ); @@ -539,7 +536,7 @@ public void testProfileModels() throws InterruptedException, IOException { public void testInitProgress() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -558,7 +555,7 @@ 
public void testInitProgress() throws IOException, InterruptedException { public void testInitProgressFailImmediately() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.NO_DOC, JobStatus.ENABLED, RCFPollingStatus.INITTING, ErrorResultStatus.NO_ERROR); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); // 123 / 128 rounded to 96% InitProgressProfile profile = new InitProgressProfile("96%", neededSamples * detectorIntervalMin, neededSamples); @@ -578,7 +575,7 @@ public void testInitProgressFailImmediately() throws IOException, InterruptedExc public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.EMPTY, ErrorResultStatus.NO_ERROR); DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", detectorIntervalMin * requiredSamples, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -600,7 +597,7 @@ public void testInitNoUpdateNoIndex() throws IOException, InterruptedException { public void testInitNoIndex() throws IOException, InterruptedException { setUpClientGet(DetectorStatus.EXIST, JobStatus.ENABLED, RCFPollingStatus.INDEX_NOT_FOUND, ErrorResultStatus.NO_ERROR); DetectorProfile expectedProfile = new DetectorProfile.Builder() - .state(DetectorState.INIT) + .state(ConfigState.INIT) .initProgress(new InitProgressProfile("0%", 0, requiredSamples)) .build(); final CountDownLatch inProgressLatch = new CountDownLatch(1); diff --git a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java index 6ff4d604d..56b44262d 100644 --- a/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java +++ b/src/test/java/org/opensearch/ad/AnomalyDetectorRestTestCase.java @@ -26,7 +26,6 @@ import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.model.AnomalyDetectorExecutionInput; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.RestClient; @@ -45,6 +44,7 @@ import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -258,7 +258,7 @@ public ToXContentObject[] getAnomalyDetector( String id = null; Long version = null; AnomalyDetector detector = null; - AnomalyDetectorJob detectorJob = null; + Job detectorJob = null; ADTask realtimeAdTask = null; ADTask historicalAdTask = null; while (parser.nextToken() != XContentParser.Token.END_OBJECT) { @@ -275,7 +275,7 @@ public ToXContentObject[] getAnomalyDetector( detector = AnomalyDetector.parse(parser); break; case "anomaly_detector_job": - detectorJob = AnomalyDetectorJob.parse(parser); + detectorJob = Job.parse(parser); break; case "realtime_detection_task": if (parser.currentToken() != XContentParser.Token.VALUE_NULL) { diff --git a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java index 
1004b01a4..cab7ddd9a 100644 --- a/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/EntityProfileRunnerTests.java @@ -35,16 +35,13 @@ import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.EntityProfile; import org.opensearch.ad.model.EntityProfileName; import org.opensearch.ad.model.EntityState; import org.opensearch.ad.model.InitProgressProfile; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.ad.transport.EntityProfileAction; import org.opensearch.ad.transport.EntityProfileResponse; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.Settings; @@ -61,6 +58,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.model.Job; public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { private AnomalyDetector detector; @@ -74,7 +72,7 @@ public class EntityProfileRunnerTests extends AbstractTimeSeriesTest { private String detectorId; private String entityValue; private int requiredSamples; - private AnomalyDetectorJob job; + private Job job; private int smallUpdates; private String categoryField; @@ -129,12 +127,12 @@ public void setUp() throws Exception { requiredSamples = 128; client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); runner = new EntityProfileRunner(client, clientUtil, xContentRegistry(), requiredSamples); diff --git a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java index b19eb7242..d702ad27d 100644 --- a/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java +++ b/src/test/java/org/opensearch/ad/HistoricalAnalysisIntegTestCase.java @@ -12,9 +12,6 @@ package org.opensearch.ad; import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD; -import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD; -import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD; -import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB; @@ -37,13 +34,9 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.mock.plugin.MockReindexPlugin; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import 
org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.transport.AnomalyDetectorJobAction; -import org.opensearch.ad.transport.AnomalyDetectorJobRequest; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -57,6 +50,11 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Feature; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.transport.JobRequest; +import org.opensearch.timeseries.transport.JobResponse; import com.google.common.collect.ImmutableList; @@ -120,6 +118,7 @@ public void ingestTestData( } } + @Override public Feature maxValueFeature() throws IOException { AggregationBuilder aggregationBuilder = TestHelpers.parseAggregation("{\"test\":{\"max\":{\"field\":\"" + valueField + "\"}}}"); return new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder); @@ -135,20 +134,14 @@ public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, DateR } public ADTask randomCreatedADTask(String taskId, AnomalyDetector detector, String detectorId, DateRange detectionDateRange) { - return randomADTask(taskId, detector, detectorId, detectionDateRange, ADTaskState.CREATED); + return randomADTask(taskId, detector, detectorId, detectionDateRange, TaskState.CREATED); } - public ADTask randomADTask( - String taskId, - AnomalyDetector detector, - String detectorId, - DateRange detectionDateRange, - ADTaskState state - ) { + public ADTask randomADTask(String taskId, AnomalyDetector detector, String detectorId, DateRange detectionDateRange, TaskState state) { ADTask.Builder builder = ADTask .builder() .taskId(taskId) - .taskType(ADTaskType.HISTORICAL_SINGLE_ENTITY.name()) + .taskType(ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR.name()) .detectorId(detectorId) .detectionDateRange(detectionDateRange) .detector(detector) @@ -158,12 +151,12 @@ public ADTask randomADTask( .isLatest(true) .startedBy(randomAlphaOfLength(5)) .executionStartTime(Instant.now().minus(randomLongBetween(10, 100), ChronoUnit.MINUTES)); - if (ADTaskState.FINISHED == state) { + if (TaskState.FINISHED == state) { setPropertyForNotRunningTask(builder); - } else if (ADTaskState.FAILED == state) { + } else if (TaskState.FAILED == state) { setPropertyForNotRunningTask(builder); builder.error(randomAlphaOfLength(5)); - } else if (ADTaskState.STOPPED == state) { + } else if (TaskState.STOPPED == state) { setPropertyForNotRunningTask(builder); builder.error(randomAlphaOfLength(5)); builder.stoppedBy(randomAlphaOfLength(5)); @@ -185,14 +178,14 @@ public List searchADTasks(String detectorId, String parentTaskId, Boolea BoolQueryBuilder query = new BoolQueryBuilder(); query.filter(new TermQueryBuilder(DETECTOR_ID_FIELD, detectorId)); if (isLatest != null) { - query.filter(new TermQueryBuilder(IS_LATEST_FIELD, isLatest)); + query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest)); } if (parentTaskId != null) { - query.filter(new TermQueryBuilder(PARENT_TASK_ID_FIELD, parentTaskId)); + query.filter(new TermQueryBuilder(TimeSeriesTask.PARENT_TASK_ID_FIELD, parentTaskId)); } SearchRequest searchRequest = new SearchRequest(); SearchSourceBuilder 
sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(query).sort(EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); + sourceBuilder.query(query).sort(TimeSeriesTask.EXECUTION_START_TIME_FIELD, SortOrder.DESC).trackTotalHits(true).size(size); searchRequest.source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX); SearchResponse searchResponse = client().search(searchRequest).actionGet(); Iterator iterator = searchResponse.getHits().iterator(); @@ -212,7 +205,7 @@ public ADTask getADTask(String taskId) throws IOException { return adTask; } - public AnomalyDetectorJob getADJob(String detectorId) throws IOException { + public Job getADJob(String detectorId) throws IOException { return toADJob(getDoc(CommonName.JOB_INDEX, detectorId)); } @@ -220,8 +213,8 @@ public ADTask toADTask(GetResponse doc) throws IOException { return ADTask.parse(TestHelpers.parser(doc.getSourceAsString())); } - public AnomalyDetectorJob toADJob(GetResponse doc) throws IOException { - return AnomalyDetectorJob.parse(TestHelpers.parser(doc.getSourceAsString())); + public Job toADJob(GetResponse doc) throws IOException { + return Job.parse(TestHelpers.parser(doc.getSourceAsString())); } public ADTask startHistoricalAnalysis(Instant startTime, Instant endTime) throws IOException { @@ -229,29 +222,15 @@ public ADTask startHistoricalAnalysis(Instant startTime, Instant endTime) throws AnomalyDetector detector = TestHelpers .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField); String detectorId = createDetector(detector); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); - AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB); + JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } public ADTask startHistoricalAnalysis(String detectorId, Instant startTime, Instant endTime) throws IOException { DateRange dateRange = new DateRange(startTime, endTime); - AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest( - detectorId, - dateRange, - true, - UNASSIGNED_SEQ_NO, - UNASSIGNED_PRIMARY_TERM, - START_JOB - ); - AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); + JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB); + JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000); return getADTask(response.getId()); } } diff --git a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java index f21b74b11..7ae0edc6b 100644 --- a/src/test/java/org/opensearch/ad/MemoryTrackerTests.java +++ b/src/test/java/org/opensearch/ad/MemoryTrackerTests.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.HashSet; -import org.opensearch.ad.breaker.ADCircuitBreakerService; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; @@ -29,6 +28,7 @@ import org.opensearch.monitor.jvm.JvmInfo.Mem; import 
org.opensearch.monitor.jvm.JvmService; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -44,7 +44,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase { int numMinSamples; int shingleSize; int dimension; - MemoryTracker tracker; + ADMemoryTracker tracker; long expectedRCFModelSize; String detectorId; long largeHeapSize; @@ -57,7 +57,7 @@ public class MemoryTrackerTests extends OpenSearchTestCase { double modelDesiredSizePercentage; JvmService jvmService; AnomalyDetector detector; - ADCircuitBreakerService circuitBreaker; + CircuitBreakerService circuitBreaker; @Override public void setUp() throws Exception { @@ -85,10 +85,10 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); modelMaxPercen = 0.1f; - Settings settings = Settings.builder().put(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.getKey(), modelMaxPercen).build(); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.getKey(), modelMaxPercen).build(); ClusterSettings clusterSettings = new ClusterSettings( settings, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -115,20 +115,20 @@ public void setUp() throws Exception { when(detector.getEnabledFeatureIds()).thenReturn(Collections.singletonList("a")); when(detector.getShingleSize()).thenReturn(1); - circuitBreaker = mock(ADCircuitBreakerService.class); + circuitBreaker = mock(CircuitBreakerService.class); when(circuitBreaker.isOpen()).thenReturn(false); } private void setUpBigHeap() { ByteSizeValue value = new ByteSizeValue(largeHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); + tracker = new ADMemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); } private void setUpSmallHeap() { ByteSizeValue value = new ByteSizeValue(smallHeapSize); when(mem.getHeapMax()).thenReturn(value); - tracker = new MemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); + tracker = new ADMemoryTracker(jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, clusterService, circuitBreaker); } public void testEstimateModelSize() { @@ -167,7 +167,7 @@ public void testEstimateModelSize() { .parallelExecutionEnabled(false) .compact(true) .precision(Precision.FLOAT_32) - .boundingBoxCacheFraction(AnomalyDetectorSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) + .boundingBoxCacheFraction(TimeSeriesSettings.BATCH_BOUNDING_BOX_CACHE_RATIO) .internalShinglingEnabled(false) // same with dimension for opportunistic memory saving .shingleSize(1) @@ -301,10 +301,10 @@ public void testCanAllocate() { assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen + 10))); long bytesToUse = 100_000; - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, false, ADMemoryTracker.Origin.HC_DETECTOR); assertTrue(!tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); - 
tracker.releaseMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); + tracker.releaseMemory(bytesToUse, false, ADMemoryTracker.Origin.HC_DETECTOR); assertTrue(tracker.canAllocate((long) (largeHeapSize * modelMaxPercen))); } @@ -318,11 +318,11 @@ public void testMemoryToShed() { long bytesToUse = 100_000; assertEquals(bytesToUse, tracker.getHeapLimit()); assertEquals((long) (smallHeapSize * modelDesiredSizePercentage), tracker.getDesiredModelSize()); - tracker.consumeMemory(bytesToUse, false, MemoryTracker.Origin.HC_DETECTOR); - tracker.consumeMemory(bytesToUse, true, MemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, false, ADMemoryTracker.Origin.HC_DETECTOR); + tracker.consumeMemory(bytesToUse, true, ADMemoryTracker.Origin.HC_DETECTOR); assertEquals(2 * bytesToUse, tracker.getTotalMemoryBytes()); assertEquals(bytesToUse, tracker.memoryToShed()); - assertTrue(!tracker.syncMemoryState(MemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); + assertTrue(!tracker.syncMemoryState(ADMemoryTracker.Origin.HC_DETECTOR, 2 * bytesToUse, bytesToUse)); } } diff --git a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java index 80ef180ed..7b83fdda9 100644 --- a/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java +++ b/src/test/java/org/opensearch/ad/MultiEntityProfileRunnerTests.java @@ -46,12 +46,10 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.AnomalyResult; import org.opensearch.ad.model.DetectorInternalState; import org.opensearch.ad.model.DetectorProfile; import org.opensearch.ad.model.DetectorProfileName; -import org.opensearch.ad.model.DetectorState; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.ad.transport.AnomalyResultTests; import org.opensearch.ad.transport.ProfileAction; @@ -66,6 +64,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Job; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -92,7 +91,7 @@ public class MultiEntityProfileRunnerTests extends AbstractTimeSeriesTest { private String model0Id; private int shingleSize; - private AnomalyDetectorJob job; + private Job job; private TransportService transportService; private ADTaskManager adTaskManager; @@ -118,7 +117,7 @@ public void setUp() throws Exception { super.setUp(); client = mock(Client.class); Clock clock = mock(Clock.class); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); nodeFilter = mock(DiscoveryNodeFilterer.class); requiredSamples = 128; @@ -135,7 +134,7 @@ public void setUp() throws Exception { function.accept(Optional.of(TestHelpers.randomAdTask())); return null; - }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(any(), any(), any(), any(), anyBoolean(), any()); + }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(any(), any(), any(), any(), anyBoolean(), any()); runner = new AnomalyDetectorProfileRunner( client, clientUtil, @@ -283,7 +282,7 @@ public void testInit() throws InterruptedException { final 
CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.INIT).build(); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.INIT).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -300,7 +299,7 @@ public void testRunning() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); @@ -321,7 +320,7 @@ public void testResultIndexFinalTruth() throws InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); - DetectorProfile expectedProfile = new DetectorProfile.Builder().state(DetectorState.RUNNING).build(); + DetectorProfile expectedProfile = new DetectorProfile.Builder().state(ConfigState.RUNNING).build(); runner.profile(detectorId, ActionListener.wrap(response -> { assertEquals(expectedProfile, response); inProgressLatch.countDown(); diff --git a/src/test/java/org/opensearch/ad/NodeStateManagerTests.java b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java index 9cad7d5eb..58300f194 100644 --- a/src/test/java/org/opensearch/ad/NodeStateManagerTests.java +++ b/src/test/java/org/opensearch/ad/NodeStateManagerTests.java @@ -18,8 +18,8 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.IOException; import java.time.Clock; @@ -44,10 +44,8 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.transport.AnomalyResultTests; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -65,6 +63,7 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.model.Job; import com.google.common.collect.ImmutableMap; @@ -84,7 +83,7 @@ public class NodeStateManagerTests extends AbstractTimeSeriesTest { private GetResponse checkpointResponse; private ClusterService clusterService; private ClusterSettings clusterSettings; - private AnomalyDetectorJob jobToCheck; + private Job jobToCheck; @Override protected NamedXContentRegistry xContentRegistry() { @@ -119,8 +118,8 @@ public void setUp() throws Exception { clientUtil = new ClientUtil(Settings.EMPTY, client, throttler, mock(ThreadPool.class)); Set> nodestateSetting = new 
HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(AD_BACKOFF_MINUTES); clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); DiscoveryNode discoveryNode = new DiscoveryNode( @@ -132,7 +131,7 @@ public void setUp() throws Exception { ); clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); - stateManager = new NodeStateManager(client, xContentRegistry(), settings, clientUtil, clock, duration, clusterService); + stateManager = new ADNodeStateManager(client, xContentRegistry(), settings, clientUtil, clock, duration, clusterService); checkpointResponse = mock(GetResponse.class); jobToCheck = TestHelpers.randomAnomalyDetectorJob(true, Instant.ofEpochMilli(1602401500000L), null); @@ -205,7 +204,7 @@ private void setupCheckpoint(boolean responseExists) throws IOException { public void testGetLastError() throws IOException, InterruptedException { String error = "blah"; - assertEquals(NodeStateManager.NO_ERROR, stateManager.getLastDetectionError(adId)); + assertEquals(ADNodeStateManager.NO_ERROR, stateManager.getLastDetectionError(adId)); stateManager.setLastDetectionError(adId, error); assertEquals(error, stateManager.getLastDetectionError(adId)); } @@ -236,7 +235,7 @@ public void testMaintenanceDoNothing() { } public void testHasRunningQuery() throws IOException { - stateManager = new NodeStateManager( + stateManager = new ADNodeStateManager( client, xContentRegistry(), settings, @@ -257,7 +256,7 @@ public void testGetAnomalyDetector() throws IOException, InterruptedException { String detectorId = setupDetector(); final CountDownLatch inProgressLatch = new CountDownLatch(1); - stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getConfig(detectorId, ActionListener.wrap(asDetector -> { assertEquals(detectorToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { @@ -277,7 +276,7 @@ public void testRepeatedGetAnomalyDetector() throws IOException, InterruptedExce String detectorId = setupDetector(); final CountDownLatch inProgressLatch = new CountDownLatch(2); - stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getConfig(detectorId, ActionListener.wrap(asDetector -> { assertEquals(detectorToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { @@ -285,7 +284,7 @@ public void testRepeatedGetAnomalyDetector() throws IOException, InterruptedExce inProgressLatch.countDown(); })); - stateManager.getAnomalyDetector(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getConfig(detectorId, ActionListener.wrap(asDetector -> { assertEquals(detectorToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { @@ -363,7 +362,7 @@ public void testSettingUpdateMaxRetry() { // In setUp method, we mute after 3 tries assertTrue(!stateManager.isMuted(nodeId, adId)); - Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), "1").build(); + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), "1").build(); Settings.Builder target = Settings.builder(); clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); clusterSettings.applySettings(target.build()); @@ -381,7 +380,7 @@ 
public void testSettingUpdateBackOffMin() { assertTrue(stateManager.isMuted(nodeId, adId)); - Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.BACKOFF_MINUTES.getKey(), "1m").build(); + Settings newSettings = Settings.builder().put(AnomalyDetectorSettings.AD_BACKOFF_MINUTES.getKey(), "1m").build(); Settings.Builder target = Settings.builder(); clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); clusterSettings.applySettings(target.build()); @@ -412,7 +411,7 @@ private String setupJob() throws IOException { public void testGetAnomalyJob() throws IOException, InterruptedException { String detectorId = setupJob(); final CountDownLatch inProgressLatch = new CountDownLatch(1); - stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getJob(detectorId, ActionListener.wrap(asDetector -> { assertEquals(jobToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { @@ -432,7 +431,7 @@ public void testRepeatedGetAnomalyJob() throws IOException, InterruptedException String detectorId = setupJob(); final CountDownLatch inProgressLatch = new CountDownLatch(2); - stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getJob(detectorId, ActionListener.wrap(asDetector -> { assertEquals(jobToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { @@ -440,7 +439,7 @@ public void testRepeatedGetAnomalyJob() throws IOException, InterruptedException inProgressLatch.countDown(); })); - stateManager.getAnomalyDetectorJob(detectorId, ActionListener.wrap(asDetector -> { + stateManager.getJob(detectorId, ActionListener.wrap(asDetector -> { assertEquals(jobToCheck, asDetector.get()); inProgressLatch.countDown(); }, exception -> { diff --git a/src/test/java/org/opensearch/ad/NodeStateTests.java b/src/test/java/org/opensearch/ad/NodeStateTests.java index c48afdb76..1af7f0ff4 100644 --- a/src/test/java/org/opensearch/ad/NodeStateTests.java +++ b/src/test/java/org/opensearch/ad/NodeStateTests.java @@ -24,14 +24,14 @@ import org.opensearch.timeseries.common.exception.TimeSeriesException; public class NodeStateTests extends OpenSearchTestCase { - private NodeState state; + private ADNodeState state; private Clock clock; @Override public void setUp() throws Exception { super.setUp(); clock = mock(Clock.class); - state = new NodeState("123", clock); + state = new ADNodeState("123", clock); } private Duration duration = Duration.ofHours(1); diff --git a/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java b/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java index 7a5be47b6..848333f87 100644 --- a/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java +++ b/src/test/java/org/opensearch/ad/breaker/ADCircuitBreakerServiceTests.java @@ -25,11 +25,14 @@ import org.mockito.MockitoAnnotations; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.timeseries.breaker.CircuitBreaker; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.breaker.MemoryCircuitBreaker; public class ADCircuitBreakerServiceTests { @InjectMocks - private ADCircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService adCircuitBreakerService; @Mock JvmService jvmService; diff --git a/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java 
b/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java index e9249df82..6264808cc 100644 --- a/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java +++ b/src/test/java/org/opensearch/ad/breaker/MemoryCircuitBreakerTests.java @@ -21,6 +21,8 @@ import org.mockito.MockitoAnnotations; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; +import org.opensearch.timeseries.breaker.CircuitBreaker; +import org.opensearch.timeseries.breaker.MemoryCircuitBreaker; public class MemoryCircuitBreakerTests { diff --git a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java index d1dde1654..11df05c70 100644 --- a/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/ad/bwc/ADBackwardsCompatibilityIT.java @@ -42,15 +42,15 @@ import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; -import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.rest.ADRestTestUtils; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.client.Response; import org.opensearch.common.settings.Settings; import org.opensearch.core.rest.RestStatus; import org.opensearch.test.rest.OpenSearchRestTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.Config; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.util.ExceptionUtil; import org.opensearch.timeseries.util.RestHandlerUtils; import com.google.common.collect.ImmutableList; @@ -70,6 +70,7 @@ public class ADBackwardsCompatibilityIT extends OpenSearchRestTestCase { private List runningRealtimeDetectors; private List historicalDetectors; + @Override @Before public void setUp() throws Exception { super.setUp(); @@ -197,7 +198,10 @@ public void testBackwardsCompatibility() throws Exception { List singleEntityDetectorResults = createRealtimeAnomalyDetectorsAndStart(SINGLE_ENTITY_DETECTOR); detectors.put(SINGLE_ENTITY_DETECTOR, singleEntityDetectorResults.get(0)); // Start historical analysis for single entity detector - startHistoricalAnalysisOnNewNode(singleEntityDetectorResults.get(0), ADTaskType.HISTORICAL_SINGLE_ENTITY.name()); + startHistoricalAnalysisOnNewNode( + singleEntityDetectorResults.get(0), + ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR.name() + ); // Create single category HC detector and start realtime job List singleCategoryHCResults = createRealtimeAnomalyDetectorsAndStart(SINGLE_CATEGORY_HC_DETECTOR); @@ -258,7 +262,7 @@ private void verifyAdTasks() throws InterruptedException, IOException { i++; for (String detectorId : runningRealtimeDetectors) { Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); - AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + Job job = (Job) jobAndTask.get(ANOMALY_DETECTOR_JOB); ADTask historicalTask = (ADTask) jobAndTask.get(HISTORICAL_ANALYSIS_TASK); ADTask realtimeTask = (ADTask) jobAndTask.get(REALTIME_TASK); assertTrue(job.isEnabled()); @@ -291,7 +295,7 @@ private void stopAndDeleteDetectors() throws Exception { } } Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); - AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + Job job = (Job) jobAndTask.get(ANOMALY_DETECTOR_JOB); ADTask historicalAdTask = (ADTask) 
jobAndTask.get(HISTORICAL_ANALYSIS_TASK); if (!job.isEnabled() && historicalAdTask.isDone()) { Response deleteDetectorResponse = deleteDetector(client(), detectorId); @@ -320,7 +324,7 @@ private void startRealtimeJobForHistoricalDetectorOnNewNode() throws IOException String jobId = startAnomalyDetectorDirectly(client(), detectorId); assertEquals(detectorId, jobId); Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); - AnomalyDetectorJob detectorJob = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + Job detectorJob = (Job) jobAndTask.get(ANOMALY_DETECTOR_JOB); assertTrue(detectorJob.isEnabled()); runningRealtimeDetectors.add(detectorId); } @@ -329,7 +333,7 @@ private void startRealtimeJobForHistoricalDetectorOnNewNode() throws IOException private void verifyAllRealtimeJobsRunning() throws IOException { for (String detectorId : runningRealtimeDetectors) { Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); - AnomalyDetectorJob detectorJob = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + Job detectorJob = (Job) jobAndTask.get(ANOMALY_DETECTOR_JOB); assertTrue(detectorJob.isEnabled()); } } @@ -434,7 +438,7 @@ private List startAnomalyDetector(Response response, boolean historicalD Map responseMap = entityAsMap(response); String detectorId = (String) responseMap.get("_id"); int version = (int) responseMap.get("_version"); - assertNotEquals("response is missing Id", AnomalyDetector.NO_ID, detectorId); + assertNotEquals("response is missing Id", Config.NO_ID, detectorId); assertTrue("incorrect version", version > 0); Response startDetectorResponse = TestHelpers @@ -452,7 +456,7 @@ private List startAnomalyDetector(Response response, boolean historicalD if (!historicalDetector) { Map jobAndTask = getDetectorWithJobAndTask(client(), detectorId); - AnomalyDetectorJob job = (AnomalyDetectorJob) jobAndTask.get(ANOMALY_DETECTOR_JOB); + Job job = (Job) jobAndTask.get(ANOMALY_DETECTOR_JOB); assertTrue(job.isEnabled()); runningRealtimeDetectors.add(detectorId); } else { diff --git a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java index 5045b45bb..346476775 100644 --- a/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java +++ b/src/test/java/org/opensearch/ad/caching/AbstractCacheTest.java @@ -22,31 +22,30 @@ import java.util.Random; import org.junit.Before; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointMaintainWorker; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; -import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.ratelimit.ADCheckpointMaintainWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.ml.ModelManager; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; public class AbstractCacheTest extends AbstractTimeSeriesTest { protected String modelId1, modelId2, modelId3, modelId4; protected Entity entity1, entity2, entity3, entity4; - protected ModelState modelState1, modelState2, modelState3, modelState4; + protected ADModelState> modelState1, modelState2, 
modelState3, modelState4; protected String detectorId; protected AnomalyDetector detector; protected Clock clock; protected Duration detectorDuration; protected float initialPriority; - protected CacheBuffer cacheBuffer; + protected ADCacheBuffer cacheBuffer; protected long memoryPerEntity; - protected MemoryTracker memoryTracker; - protected CheckpointWriteWorker checkpointWriteQueue; - protected CheckpointMaintainWorker checkpointMaintainQueue; + protected ADMemoryTracker memoryTracker; + protected ADCheckpointWriteWorker checkpointWriteQueue; + protected ADCheckpointMaintainWorker checkpointMaintainQueue; protected Random random; protected int shingleSize; @@ -83,18 +82,18 @@ public void setUp() throws Exception { when(clock.instant()).thenReturn(Instant.now()); memoryPerEntity = 81920; - memoryTracker = mock(MemoryTracker.class); + memoryTracker = mock(ADMemoryTracker.class); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); - checkpointMaintainQueue = mock(CheckpointMaintainWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); + checkpointMaintainQueue = mock(ADCheckpointMaintainWorker.class); - cacheBuffer = new CacheBuffer( + cacheBuffer = new ADCacheBuffer( 1, 1, memoryPerEntity, memoryTracker, clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, detectorId, checkpointWriteQueue, checkpointMaintainQueue, @@ -103,38 +102,38 @@ public void setUp() throws Exception { initialPriority = cacheBuffer.getPriorityTracker().getUpdatedPriority(0); - modelState1 = new ModelState<>( - new EntityModel(entity1, new ArrayDeque<>(), null), + modelState1 = new ADModelState<>( + new createFromValueOnlySamples(entity1, new ArrayDeque<>(), null), modelId1, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); - modelState2 = new ModelState<>( - new EntityModel(entity2, new ArrayDeque<>(), null), + modelState2 = new ADModelState<>( + new createFromValueOnlySamples(entity2, new ArrayDeque<>(), null), modelId2, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); - modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + modelState3 = new ADModelState<>( + new createFromValueOnlySamples(entity3, new ArrayDeque<>(), null), modelId3, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); - modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + modelState4 = new ADModelState<>( + new createFromValueOnlySamples(entity4, new ArrayDeque<>(), null), modelId4, detectorId, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); diff --git a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java index 7332edf4b..515ae37af 100644 --- a/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java +++ b/src/test/java/org/opensearch/ad/caching/CacheBufferTests.java @@ -22,8 +22,8 @@ import java.util.Optional; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.ratelimit.CheckpointMaintainRequest; +import org.opensearch.ad.ADMemoryTracker; +import org.opensearch.timeseries.ratelimit.ModelRequest; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -69,21 +69,21 @@ public void testRemovalCandidate2() throws InterruptedException { cacheBuffer.put(modelId2, 
modelState2); cacheBuffer.put(modelId2, modelState2); cacheBuffer.put(modelId4, modelState4); - assertTrue(cacheBuffer.getModel(modelId2).isPresent()); + assertTrue(cacheBuffer.getModelState(modelId2).isPresent()); ArgumentCaptor memoryReleased = ArgumentCaptor.forClass(Long.class); ArgumentCaptor reserved = ArgumentCaptor.forClass(Boolean.class); - ArgumentCaptor orign = ArgumentCaptor.forClass(MemoryTracker.Origin.class); + ArgumentCaptor orign = ArgumentCaptor.forClass(ADMemoryTracker.Origin.class); cacheBuffer.clear(); verify(memoryTracker, times(2)).releaseMemory(memoryReleased.capture(), reserved.capture(), orign.capture()); List capturedMemoryReleased = memoryReleased.getAllValues(); List capturedreserved = reserved.getAllValues(); - List capturedOrigin = orign.getAllValues(); + List capturedOrigin = orign.getAllValues(); assertEquals(3 * memoryPerEntity, capturedMemoryReleased.stream().reduce(0L, (a, b) -> a + b).intValue()); assertTrue(capturedreserved.get(0)); assertTrue(!capturedreserved.get(1)); - assertEquals(MemoryTracker.Origin.HC_DETECTOR, capturedOrigin.get(0)); + assertEquals(ADMemoryTracker.Origin.HC_DETECTOR, capturedOrigin.get(0)); assertTrue(!cacheBuffer.expired(Duration.ofHours(1))); } @@ -117,7 +117,7 @@ public void testMaintenance() { cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.maintenance(); assertEquals(3, cacheBuffer.getActiveEntities()); - assertEquals(3, cacheBuffer.getAllModels().size()); + assertEquals(3, cacheBuffer.getAllModelStates().size()); // the year of 2122, 100 years later to simulate we are gonna remove all cached entries when(clock.instant()).thenReturn(Instant.ofEpochSecond(4814540761L)); cacheBuffer.maintenance(); @@ -138,7 +138,7 @@ public void testMaintainByHourNothingToSave() { cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); - ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); cacheBuffer.maintenance(); verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); assertTrue(savedStates.getValue().isEmpty()); @@ -162,12 +162,12 @@ public void testMaintainByHourSaveOne() { cacheBuffer.put(modelId2, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); cacheBuffer.put(modelId3, MLUtil.randomModelState(new RandomModelStateConfig.Builder().priority(initialPriority).build())); - ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); + ArgumentCaptor> savedStates = ArgumentCaptor.forClass(List.class); cacheBuffer.maintenance(); verify(checkpointMaintainQueue, times(1)).putAll(savedStates.capture()); - List toSave = savedStates.getValue(); + List toSave = savedStates.getValue(); assertEquals(1, toSave.size()); - assertEquals(modelId1, toSave.get(0).getEntityModelId()); + assertEquals(modelId1, toSave.get(0).getModelId()); } /** diff --git a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java index 7774fb314..f6034d25d 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityCacheTests.java @@ -44,13 +44,8 @@ import org.apache.logging.log4j.Logger; import 
org.junit.Before; import org.mockito.ArgumentCaptor; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelManager.ModelType; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.ADEnabledSetting; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -63,16 +58,21 @@ import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.breaker.CircuitBreakerService; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.common.exception.TimeSeriesException; +import org.opensearch.timeseries.ml.ModelManager; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; + +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; public class PriorityCacheTests extends AbstractCacheTest { private static final Logger LOG = LogManager.getLogger(PriorityCacheTests.class); EntityCache entityCache; - CheckpointDao checkpoint; - ModelManager modelManager; + ADCheckpointDao checkpoint; + ADModelManager modelManager; ClusterService clusterService; Settings settings; @@ -86,9 +86,9 @@ public class PriorityCacheTests extends AbstractCacheTest { public void setUp() throws Exception { super.setUp(); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); clusterService = mock(ClusterService.class); ClusterSettings settings = new ClusterSettings( @@ -98,10 +98,10 @@ public void setUp() throws Exception { new HashSet<>( Arrays .asList( - AnomalyDetectorSettings.DEDICATED_CACHE_SIZE, - AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE, + AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ ) ) @@ -114,25 +114,25 @@ public void setUp() throws Exception { threadPool = mock(ThreadPool.class); setUpADThreadPool(threadPool); - EntityCache cache = new PriorityCache( + EntityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, memoryTracker, - AnomalyDetectorSettings.NUM_TREES, + TimeSeriesSettings.NUM_TREES, clock, clusterService, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, checkpointWriteQueue, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, checkpointMaintainQueue, Settings.EMPTY, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ ); - CacheProvider cacheProvider = new CacheProvider(); + HCCacheProvider cacheProvider = new HCCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -167,27 +167,27 @@ public void testCacheHit() { // 
); // when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - memoryTracker = spy(new MemoryTracker(jvmService, modelMaxPercen, 0.002, clusterService, mock(ADCircuitBreakerService.class))); + memoryTracker = spy(new ADMemoryTracker(jvmService, modelMaxPercen, 0.002, clusterService, mock(CircuitBreakerService.class))); - EntityCache cache = new PriorityCache( + EntityCache cache = new ADPriorityCache( checkpoint, dedicatedCacheSize, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, AnomalyDetectorSettings.MAX_INACTIVE_ENTITIES, memoryTracker, - AnomalyDetectorSettings.NUM_TREES, + TimeSeriesSettings.NUM_TREES, clock, clusterService, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, checkpointWriteQueue, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, checkpointMaintainQueue, Settings.EMPTY, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ ); - CacheProvider cacheProvider = new CacheProvider(); + HCCacheProvider cacheProvider = new HCCacheProvider(); cacheProvider.set(cache); entityCache = cacheProvider.get(); @@ -198,24 +198,24 @@ public void testCacheHit() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getTotalActiveEntities()); assertEquals(1, entityCache.getAllModels().size()); - ModelState<EntityModel> hitState = entityCache.get(modelState1.getModelId(), detector); - assertEquals(detectorId, hitState.getId()); - EntityModel model = hitState.getModel(); - assertEquals(false, model.getTrcf().isPresent()); - assertTrue(model.getSamples().isEmpty()); - modelState1.getModel().addSample(point); - assertTrue(Arrays.equals(point, model.getSamples().peek())); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> hitState = entityCache.get(modelState1.getModelId(), detector); + assertEquals(detectorId, hitState.getConfigId()); + createFromValueOnlySamples<ThresholdedRandomCutForest> model = hitState.getModel(); + assertEquals(false, model.getModel().isPresent()); + assertTrue(model.getValueOnlySamples().isEmpty()); + modelState1.getModel().addValueOnlySample(point); + assertTrue(Arrays.equals(point, model.getValueOnlySamples().peek())); ArgumentCaptor<Long> memoryConsumed = ArgumentCaptor.forClass(Long.class); ArgumentCaptor<Boolean> reserved = ArgumentCaptor.forClass(Boolean.class); - ArgumentCaptor<MemoryTracker.Origin> origin = ArgumentCaptor.forClass(MemoryTracker.Origin.class); + ArgumentCaptor<ADMemoryTracker.Origin> origin = ArgumentCaptor.forClass(ADMemoryTracker.Origin.class); // input dimension: 3, shingle: 4 long expectedMemoryPerEntity = 436828L; verify(memoryTracker, times(1)).consumeMemory(memoryConsumed.capture(), reserved.capture(), origin.capture()); assertEquals(dedicatedCacheSize * expectedMemoryPerEntity, memoryConsumed.getValue().intValue()); assertEquals(true, reserved.getValue().booleanValue()); - assertEquals(MemoryTracker.Origin.HC_DETECTOR, origin.getValue()); + assertEquals(ADMemoryTracker.Origin.HC_DETECTOR, origin.getValue()); // for (int i = 0; i < 2; i++) { // cacheProvider.get(modelId2, detector); @@ -257,11 +257,11 @@ public void testSharedCache() { for (int i = 0; i < 10; i++) { entityCache.get(modelId3, detector2); } - modelState3 = new ModelState<>( - new EntityModel(entity3, new ArrayDeque<>(), null), + modelState3 = new ADModelState<>( + new createFromValueOnlySamples<ThresholdedRandomCutForest>(entity3, new ArrayDeque<>(), null), modelId3, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); @@ -273,11 +273,11 @@ // replace modelId2 in shared cache
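/*
 * The testCacheHit hunk above pins down the exact number of bytes reserved for the
 * dedicated cache slots via ArgumentCaptor. A minimal, self-contained sketch of that
 * Mockito pattern against a hypothetical Tracker interface (the interface name is
 * illustrative, not plugin API; 436828L is the per-entity estimate asserted above):
 *
 *   interface Tracker { void consumeMemory(long bytes, boolean reserved); }
 *
 *   Tracker tracker = mock(Tracker.class);
 *   tracker.consumeMemory(436828L * 10, true);                       // code under test
 *   ArgumentCaptor<Long> bytes = ArgumentCaptor.forClass(Long.class);
 *   verify(tracker, times(1)).consumeMemory(bytes.capture(), eq(true));
 *   assertEquals(436828L * 10, bytes.getValue().longValue());
 */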
entityCache.get(modelId4, detector2); } - modelState4 = new ModelState<>( - new EntityModel(entity4, new ArrayDeque<>(), null), + modelState4 = new ADModelState<>( + new createFromValueOnlySamples<ThresholdedRandomCutForest>(entity4, new ArrayDeque<>(), null), modelId4, detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); @@ -301,7 +301,7 @@ public void testReplace() { entityCache.hostIfPossible(detector, modelState1); assertEquals(1, entityCache.getActiveEntities(detectorId)); when(memoryTracker.canAllocate(anyLong())).thenReturn(false); - ModelState<EntityModel> state = null; + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = null; for (int i = 0; i < 4; i++) { entityCache.get(modelId2, detector); @@ -364,7 +364,7 @@ public void testClear() { assertEquals(2, entityCache.getTotalActiveEntities()); assertTrue(entityCache.isActive(detectorId, modelId1)); assertEquals(0, entityCache.getTotalUpdates(detectorId)); - modelState1.getModel().addSample(point); + modelState1.getModel().addValueOnlySample(point); assertEquals(1, entityCache.getTotalUpdates(detectorId)); assertEquals(1, entityCache.getTotalUpdates(detectorId, modelId1)); entityCache.clear(detectorId); @@ -410,7 +410,7 @@ public void testSuccessfulConcurrentMaintenance() { doAnswer(invocation -> { inProgressLatch.await(100, TimeUnit.SECONDS); return null; - }).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(MemoryTracker.Origin.class)); + }).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(ADMemoryTracker.Origin.class)); doAnswer(invocation -> { inProgressLatch.countDown(); @@ -448,12 +448,12 @@ public void testFailedConcurrentMaintenance() throws InterruptedException { final CountDownLatch scheduleCountDown = new CountDownLatch(1); final CountDownLatch scheduledThreadCountDown = new CountDownLatch(1); - doThrow(NullPointerException.class).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(MemoryTracker.Origin.class)); + doThrow(NullPointerException.class).when(memoryTracker).releaseMemory(anyLong(), anyBoolean(), any(ADMemoryTracker.Origin.class)); doAnswer(invocation -> { scheduleCountDown.await(100, TimeUnit.SECONDS); return null; - }).when(memoryTracker).syncMemoryState(any(MemoryTracker.Origin.class), anyLong(), anyLong()); + }).when(memoryTracker).syncMemoryState(any(ADMemoryTracker.Origin.class), anyLong(), anyLong()); AtomicReference<Runnable> runnable = new AtomicReference<Runnable>(); doAnswer(invocation -> { @@ -536,19 +536,19 @@ public void testSelectToReplaceInCache() { private void replaceInOtherCacheSetUp() { Entity entity5 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal5"); Entity entity6 = Entity.createSingleAttributeEntity("attributeName1", "attributeVal6"); - ModelState<EntityModel> modelState5 = new ModelState<>( - new EntityModel(entity5, new ArrayDeque<>(), null), + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState5 = new ADModelState<>( + new createFromValueOnlySamples<ThresholdedRandomCutForest>(entity5, new ArrayDeque<>(), null), entity5.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); - ModelState<EntityModel> modelState6 = new ModelState<>( - new EntityModel(entity6, new ArrayDeque<>(), null), + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState6 = new ADModelState<>( + new createFromValueOnlySamples<ThresholdedRandomCutForest>(entity6, new ArrayDeque<>(), null), entity6.getModelId(detectorId2).get(), detectorId2, - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, 0 ); @@ -658,7 +658,7 @@ public void testLongDetectorInterval() { String modelId = entity1.getModelId(detectorId).get(); // record last access time
1000 assertTrue(null == entityCache.get(modelId, detector)); - assertEquals(-1, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(-1, entityCache.getLastActiveTime(detectorId, modelId)); // 2 hours = 7200 seconds have passed long currentTimeEpoch = 8200; when(clock.instant()).thenReturn(Instant.ofEpochSecond(currentTimeEpoch)); @@ -667,7 +667,7 @@ // door keeper still has the record and won't block entity state being created entityCache.get(modelId, detector); // * 1000 to convert to milliseconds - assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveMs(detectorId, modelId)); + assertEquals(currentTimeEpoch * 1000, entityCache.getLastActiveTime(detectorId, modelId)); } finally { ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.DOOR_KEEPER_IN_CACHE_ENABLED, false); } @@ -723,7 +723,7 @@ public void testRemoveEntityModel() { assertTrue(null != entityCache.get(entity2.getModelId(detectorId).get(), detector)); - entityCache.removeEntityModel(detectorId, entity2.getModelId(detectorId).get()); + entityCache.removeModel(detectorId, entity2.getModelId(detectorId).get()); assertTrue(null == entityCache.get(entity2.getModelId(detectorId).get(), detector)); diff --git a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java index 4e721d68e..09cc23bd6 100644 --- a/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java +++ b/src/test/java/org/opensearch/ad/caching/PriorityTrackerTests.java @@ -21,6 +21,7 @@ import org.junit.Before; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.caching.PriorityTracker; public class PriorityTrackerTests extends OpenSearchTestCase { Clock clock; diff --git a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java index 415ec75fe..9d7e60c6e 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADClusterEventListenerTests.java @@ -38,6 +38,8 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.gateway.GatewayService; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterEventListener; +import org.opensearch.timeseries.cluster.HashRing; public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterManagerNodeId = "clusterManagerNode"; @@ -45,7 +47,7 @@ public class ADClusterEventListenerTests extends AbstractTimeSeriesTest { private final String clusterName = "multi-node-cluster"; private ClusterService clusterService; - private ADClusterEventListener listener; + private ClusterEventListener listener; private HashRing hashRing; private ClusterState oldClusterState; private ClusterState newClusterState; @@ -66,7 +68,7 @@ public static void tearDownAfterClass() { @Before public void setUp() throws Exception { super.setUp(); - super.setUpLog4jForJUnit(ADClusterEventListener.class); + super.setUpLog4jForJUnit(ClusterEventListener.class); clusterService = createClusterService(threadPool); hashRing = mock(HashRing.class); @@ -98,7 +100,7 @@ public void setUp() throws Exception { ) .build(); - listener = new ADClusterEventListener(clusterService, hashRing); + listener = new ClusterEventListener(clusterService, hashRing); } @Override @@ -114,7 +116,7 @@ public void tearDown() throws Exception {
public void testUnchangedClusterState() { listener.clusterChanged(new ClusterChangedEvent("foo", oldClusterState, oldClusterState)); - assertTrue(!testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(!testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); } public void testIsWarmNode() { @@ -134,7 +136,7 @@ public void testIsWarmNode() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", warmNodeClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } public void testNotRecovered() { @@ -150,7 +152,7 @@ public void testNotRecovered() { .blocks(ClusterBlocks.builder().addGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) .build(); listener.clusterChanged(new ClusterChangedEvent("foo", blockedClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NOT_RECOVERED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NOT_RECOVERED_MSG)); } class ListenerRunnable implements Runnable { @@ -170,7 +172,7 @@ public void testInProgress() { }).when(hashRing).buildCircles(any(), any()); new Thread(new ListenerRunnable()).start(); listener.clusterChanged(new ClusterChangedEvent("bar", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.IN_PROGRESS_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.IN_PROGRESS_MSG)); } public void testNodeAdded() { @@ -182,10 +184,10 @@ public void testNodeAdded() { doAnswer(invocation -> Optional.of(clusterManagerNode)) .when(hashRing) - .getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)); + .getOwningNodeWithSameLocalVersionForRealtime(any(String.class)); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, oldClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: false, node added: true")); } @@ -203,7 +205,7 @@ public void testNodeRemoved() { .build(); listener.clusterChanged(new ClusterChangedEvent("foo", newClusterState, twoDataNodeClusterState)); - assertTrue(testAppender.containsMessage(ADClusterEventListener.NODE_CHANGED_MSG)); + assertTrue(testAppender.containsMessage(ClusterEventListener.NODE_CHANGED_MSG)); assertTrue(testAppender.containsMessage("node removed: true, node added: true")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java index aa5fcc55b..79f1cd26d 100644 --- a/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ADVersionUtilTests.java @@ -13,22 +13,23 @@ import org.opensearch.Version; import org.opensearch.ad.ADUnitTestCase; +import org.opensearch.timeseries.cluster.VersionUtil; public class ADVersionUtilTests extends ADUnitTestCase { public void testParseVersionFromString() { - Version version = ADVersionUtil.fromString("2.1.0.0"); + Version version = VersionUtil.fromString("2.1.0.0"); assertEquals(Version.V_2_1_0, version); - version = ADVersionUtil.fromString("2.1.0"); + version = VersionUtil.fromString("2.1.0"); assertEquals(Version.V_2_1_0, version); 
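/*
 * VersionUtil.fromString accepts both the four-part plugin version string ("2.1.0.0")
 * and the three-part core form ("2.1.0"), resolving both to the same Version constant,
 * while null and under-specified inputs like "1.1" throw IllegalArgumentException. A
 * hedged sketch of a caller guarding against that (the Version.CURRENT fallback is an
 * assumption for illustration, not plugin behavior):
 *
 *   Version remote;
 *   try {
 *       remote = VersionUtil.fromString(nodeVersionString);
 *   } catch (IllegalArgumentException e) {
 *       remote = Version.CURRENT; // assumed fallback for unparseable versions
 *   }
 */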
} public void testParseVersionFromStringWithNull() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString(null)); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString(null)); } public void testParseVersionFromStringWithWrongFormat() { - expectThrows(IllegalArgumentException.class, () -> ADVersionUtil.fromString("1.1")); + expectThrows(IllegalArgumentException.class, () -> VersionUtil.fromString("1.1")); } } diff --git a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java index 637c5e10e..58fdadf40 100644 --- a/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java +++ b/src/test/java/org/opensearch/ad/cluster/ClusterManagerEventListenerTests.java @@ -30,7 +30,6 @@ import org.opensearch.ad.cluster.diskcleanup.ModelCheckpointIndexRetention; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.component.LifecycleListener; @@ -40,6 +39,9 @@ import org.opensearch.threadpool.Scheduler.Cancellable; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.ClusterManagerEventListener; +import org.opensearch.timeseries.cluster.HourlyCron; +import org.opensearch.timeseries.util.ClientUtil; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class ClusterManagerEventListenerTests extends AbstractTimeSeriesTest { @@ -60,7 +62,7 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); ClusterSettings settings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.CHECKPOINT_TTL))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_CHECKPOINT_TTL))) ); when(clusterService.getClusterSettings()).thenReturn(settings); @@ -85,7 +87,7 @@ public void setUp() throws Exception { clock, clientUtil, nodeFilter, - AnomalyDetectorSettings.CHECKPOINT_TTL, + AnomalyDetectorSettings.AD_CHECKPOINT_TTL, Settings.EMPTY ); } diff --git a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java index 63d48ef3c..5f7ecefa5 100644 --- a/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/DailyCronTests.java @@ -24,11 +24,11 @@ import org.opensearch.OpenSearchException; import org.opensearch.action.ActionListener; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.reindex.BulkByScrollResponse; import org.opensearch.index.reindex.DeleteByQueryAction; -import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.DailyCron; +import org.opensearch.timeseries.util.ClientUtil; public class DailyCronTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java index e85051dd9..f5bb8c2d2 100644 --- a/src/test/java/org/opensearch/ad/cluster/HashRingTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HashRingTests.java @@ -19,7 +19,7 @@ import static org.mockito.Mockito.mock; import 
static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.COOLDOWN_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_COOLDOWN_MINUTES; import java.net.UnknownHostException; import java.time.Clock; @@ -38,7 +38,6 @@ import org.opensearch.action.admin.cluster.node.info.PluginsAndModules; import org.opensearch.ad.ADUnitTestCase; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.ModelManager; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; @@ -50,6 +49,8 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.plugins.PluginInfo; +import org.opensearch.timeseries.cluster.ADDataMigrator; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -74,7 +75,7 @@ public class HashRingTests extends ADUnitTestCase { private DiscoveryNode localNode; private DiscoveryNode newNode; private DiscoveryNode warmNode; - private ModelManager modelManager; + private ADModelManager modelManager; @Override @Before @@ -88,8 +89,8 @@ public void setUp() throws Exception { warmNodeId = "warmNode"; warmNode = createNode(warmNodeId, "127.0.0.3", 9202, ImmutableMap.of(ADCommonName.BOX_TYPE_KEY, ADCommonName.WARM_BOX_TYPE)); - settings = Settings.builder().put(COOLDOWN_MINUTES.getKey(), TimeValue.timeValueSeconds(5)).build(); + settings = Settings.builder().put(AD_COOLDOWN_MINUTES.getKey(), TimeValue.timeValueSeconds(5)).build(); - ClusterSettings clusterSettings = clusterSetting(settings, COOLDOWN_MINUTES); + ClusterSettings clusterSettings = clusterSetting(settings, AD_COOLDOWN_MINUTES); clusterService = spy(new ClusterService(settings, clusterSettings, null)); nodeFilter = spy(new DiscoveryNodeFilterer(clusterService)); @@ -107,7 +108,7 @@ public void setUp() throws Exception { when(adminClient.cluster()).thenReturn(clusterAdminClient); String modelId = "123_model_threshold"; - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); doAnswer(invocation -> { Set<String> res = new HashSet<>(); res.add(modelId); @@ -121,7 +122,7 @@ public void testGetOwningNodeWithEmptyResult() throws UnknownHostException { DiscoveryNode node1 = createNode(Integer.toString(1), "127.0.0.4", 9204, emptyMap()); doReturn(node1).when(clusterService).localNode(); - Optional<DiscoveryNode> node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional<DiscoveryNode> node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertFalse(node.isPresent()); } @@ -130,10 +131,10 @@ public void testGetOwningNode() throws UnknownHostException { // Add first node, hashRing.buildCircles(delta, ActionListener.wrap(r -> { - Optional<DiscoveryNode> node = hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD("http-latency-rcf-1"); + Optional<DiscoveryNode> node = hashRing.getOwningNodeWithSameLocalVersionForRealtime("http-latency-rcf-1"); assertTrue(node.isPresent()); assertTrue(asList(newNodeId, localNodeId).contains(node.get().getId())); - DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalAdVersion(); + DiscoveryNode[] nodesWithSameLocalAdVersion = hashRing.getNodesWithSameLocalVersion(); Set<String> nodesWithSameLocalAdVersionIds = new HashSet<>(); for (DiscoveryNode n : nodesWithSameLocalAdVersion) {
nodesWithSameLocalAdVersionIds.add(n.getId()); @@ -143,10 +144,10 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 2, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will change as they are eligible to be built when empty - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Build hash ring failed", true); @@ -162,10 +163,10 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 3, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); // Circles for realtime AD will not change as they are not eligible to rebuild - assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 2, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); @@ -183,9 +184,9 @@ public void testGetOwningNode() throws UnknownHostException { assertEquals( "Wrong hash ring size for historical analysis", 4, - hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, false).size() + hashRing.getNodesWithSameVersion(Version.V_2_1_0, false).size() ); - assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameAdVersion(Version.V_2_1_0, true).size()); + assertEquals("Wrong hash ring size for realtime AD", 4, hashRing.getNodesWithSameVersion(Version.V_2_1_0, true).size()); }, e -> { logger.error("building hash ring failed", e); assertFalse("Failed to build hash ring", true); @@ -194,7 +195,7 @@ public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { setupNodeDelta(); - hashRing.getAllEligibleDataNodesWithKnownAdVersion(nodes -> { + hashRing.getAllEligibleDataNodesWithKnownVersion(nodes -> { assertEquals("Wrong hash ring size for historical analysis", 2, nodes.length); Optional<DiscoveryNode> node = hashRing.getNodeByAddress(newNode.getAddress()); assertTrue(node.isPresent()); @@ -205,7 +206,7 @@ public void testGetAllEligibleDataNodesWithKnownAdVersionAndGetNodeByAddress() { public void testBuildAndGetOwningNodeWithSameLocalAdVersion() { setupNodeDelta(); hashRing - .buildAndGetOwningNodeWithSameLocalAdVersion( + .buildAndGetOwningNodeWithSameLocalVersion( "testModelId", node -> { assertTrue(node.isPresent()); }, ActionListener.wrap(r -> {}, e -> { assertFalse("Failed to build hash ring", true); }) diff --git a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java index 6461a7b3e..15efdd6e4 100644 --- a/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java +++ b/src/test/java/org/opensearch/ad/cluster/HourlyCronTests.java @@ -39,6 +39,7 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.cluster.HourlyCron; import
org.opensearch.timeseries.util.DiscoveryNodeFilterer; import test.org.opensearch.ad.util.ClusterCreation; diff --git a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java index 1425a5ec3..23c17ae23 100644 --- a/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java +++ b/src/test/java/org/opensearch/ad/cluster/diskcleanup/IndexCleanupTests.java @@ -29,7 +29,6 @@ import org.opensearch.action.admin.indices.stats.CommonStats; import org.opensearch.action.admin.indices.stats.IndicesStatsResponse; import org.opensearch.action.admin.indices.stats.ShardStats; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.client.IndicesAdminClient; import org.opensearch.cluster.service.ClusterService; @@ -38,6 +37,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.store.StoreStats; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.util.ClientUtil; public class IndexCleanupTests extends AbstractTimeSeriesTest { diff --git a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java index 919b3e068..d9b5dc9da 100644 --- a/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java +++ b/src/test/java/org/opensearch/ad/e2e/AbstractSyntheticDataTest.java @@ -11,9 +11,9 @@ package org.opensearch.ad.e2e; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import static org.opensearch.timeseries.TestHelpers.toHttpEntity; +import static org.opensearch.ad.TestHelpers.toHttpEntity; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.io.File; import java.io.FileReader; @@ -61,8 +61,8 @@ protected void disableResourceNotFoundFaultTolerence() throws IOException { settingCommand.startObject(); settingCommand.startObject("persistent"); - settingCommand.field(MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); - settingCommand.field(BACKOFF_MINUTES.getKey(), 0); + settingCommand.field(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.getKey(), 100_000); + settingCommand.field(AD_BACKOFF_MINUTES.getKey(), 0); settingCommand.endObject(); settingCommand.endObject(); Request request = new Request("PUT", "/_cluster/settings"); diff --git a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java index b5ce70d05..29cc87058 100644 --- a/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeatureManagerTests.java @@ -63,6 +63,10 @@ import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.feature.Features; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.feature.SinglePointFeatures; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; diff --git 
a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java index 447bdd6c4..fcba30fa6 100644 --- a/src/test/java/org/opensearch/ad/feature/FeaturesTests.java +++ b/src/test/java/org/opensearch/ad/feature/FeaturesTests.java @@ -23,6 +23,7 @@ import org.junit.Test; import org.junit.runner.RunWith; +import org.opensearch.timeseries.feature.Features; @RunWith(JUnitParamsRunner.class) public class FeaturesTests { diff --git a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java index 1d0da6d19..9e687b437 100644 --- a/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/NoPowermockSearchFeatureDaoTests.java @@ -55,10 +55,10 @@ import org.opensearch.action.search.SearchResponse.Clusters; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.ad.NodeStateManager; +import org.opensearch.ad.ADNodeStateManager; +import org.opensearch.ad.TestHelpers; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.lease.Releasables; @@ -100,6 +100,7 @@ import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.Feature; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import com.carrotsearch.hppc.BitMixer; import com.google.common.collect.ImmutableList; @@ -161,18 +162,18 @@ public void setUp() throws Exception { Settings.EMPTY, Collections .unmodifiableSet( - new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, AnomalyDetectorSettings.PAGE_SIZE)) + new HashSet<>(Arrays.asList(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW, AnomalyDetectorSettings.AD_PAGE_SIZE)) ) ); clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); clock = mock(Clock.class); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); doAnswer(invocation -> { ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = new SearchFeatureDao( @@ -182,7 +183,7 @@ public void setUp() throws Exception { clientUtil, settings, clusterService, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, clock, 1, 1, @@ -369,7 +370,7 @@ public void testGetHighestCountEntitiesExhaustedPages() throws InterruptedExcept clientUtil, settings, clusterService, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, clock, 2, 1, @@ -415,7 +416,7 @@ public void testGetHighestCountEntitiesNotEnoughTime() throws InterruptedExcepti clientUtil, settings, clusterService, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, clock, 2, 1, diff --git
a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java index e00225ef0..88b42dedd 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoParamTests.java @@ -54,7 +54,6 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -80,6 +79,7 @@ import org.opensearch.timeseries.dataprocessor.Imputer; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; @@ -178,7 +178,7 @@ public void setup() throws Exception { }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = spy( - new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, TimeSeriesSettings.NUM_SAMPLES_PER_TREE) ); detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); diff --git a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java index cf18b2fdd..36ee5044c 100644 --- a/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java +++ b/src/test/java/org/opensearch/ad/feature/SearchFeatureDaoTests.java @@ -57,10 +57,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.time.DateFormatter; @@ -96,6 +93,7 @@ import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.ParseUtils; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PowerMockIgnore; @@ -135,7 +133,7 @@ public class SearchFeatureDaoTests { @Mock private Max max; @Mock - private NodeStateManager stateManager; + private ADNodeStateManager stateManager; @Mock private AnomalyDetector detector; @@ -173,15 +171,15 @@ public void setup() throws Exception { settings = Settings.EMPTY; when(client.threadPool()).thenReturn(threadPool); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); doAnswer(invocation -> { ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
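/*
 * This is the standard Mockito idiom for stubbing OpenSearch's callback-style APIs:
 * pull the ActionListener out of the invocation and complete it inline, so the test
 * never blocks on a real async call. A generic sketch of the pattern used throughout
 * these setups (T, cannedValue, and asyncCall are placeholders, not plugin API):
 *
 *   doAnswer(invocation -> {
 *       ActionListener<T> l = invocation.getArgument(1);
 *       l.onResponse(cannedValue);   // or l.onFailure(new RuntimeException()) for error paths
 *       return null;
 *   }).when(mockedDependency).asyncCall(any(), any(ActionListener.class));
 */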
listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class)); clientUtil = new SecurityClientUtil(nodeStateManager, settings); searchFeatureDao = spy( - new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE) + new SearchFeatureDao(client, xContent, imputer, clientUtil, settings, null, TimeSeriesSettings.NUM_SAMPLES_PER_TREE) ); detectionInterval = new IntervalTimeConfiguration(1, ChronoUnit.MINUTES); @@ -378,7 +376,7 @@ public void testGetEntityMinDataTime() { ActionListener<Optional<Long>> listener = mock(ActionListener.class); Entity entity = Entity.createSingleAttributeEntity("field", "app_1"); - searchFeatureDao.getEntityMinDataTime(detector, entity, listener); + searchFeatureDao.getMinDataTime(detector, entity, listener); ArgumentCaptor<Optional<Long>> captor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(captor.capture()); diff --git a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java index 313800385..eaf0d1d8a 100644 --- a/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java +++ b/src/test/java/org/opensearch/ad/indices/AnomalyDetectionIndicesTests.java @@ -23,8 +23,6 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.constant.CommonName; -import org.opensearch.timeseries.indices.IndexManagementIntegTestCase; -import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; public class AnomalyDetectionIndicesTests extends IndexManagementIntegTestCase { diff --git a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java index 53bea9015..7c9694a41 100644 --- a/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java +++ b/src/test/java/org/opensearch/ad/indices/CustomIndexTests.java @@ -33,7 +33,6 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; -import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -188,7 +187,7 @@ private Map<String, Object> createMapping() { attribution_nested_mapping.put("feature_id", Collections.singletonMap("type", "keyword")); mappings.put(AnomalyResult.RELEVANT_ATTRIBUTION_FIELD, attribution_mapping); - mappings.put(CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); + mappings.put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, Collections.singletonMap("type", "integer")); mappings.put(CommonName.TASK_ID_FIELD, Collections.singletonMap("type", "keyword")); diff --git a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java index f53393014..fb65f45af 100644 --- a/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java +++ b/src/test/java/org/opensearch/ad/indices/UpdateMappingTests.java @@ -55,7 +55,6 @@ import org.opensearch.common.settings.Settings; import
org.opensearch.index.IndexNotFoundException; import org.opensearch.timeseries.AbstractTimeSeriesTest; -import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; @@ -168,7 +167,7 @@ public void testUpdateMapping() throws IOException { put(ADIndexManagement.META, new HashMap<String, Object>() { { // version 1 will cause update - put(CommonName.SCHEMA_VERSION_FIELD, 1); + put(org.opensearch.timeseries.constant.CommonName.SCHEMA_VERSION_FIELD, 1); } }); } diff --git a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java index e7d74f28a..7ca486e4a 100644 --- a/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java +++ b/src/test/java/org/opensearch/ad/ml/AbstractCosineDataTest.java @@ -15,9 +15,8 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; import java.time.Clock; import java.time.Instant; @@ -33,14 +32,9 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; @@ -59,7 +53,9 @@ import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; +import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; import com.google.common.collect.ImmutableList; public class AbstractCosineDataTest extends AbstractTimeSeriesTest { @@ -67,11 +63,11 @@ public class AbstractCosineDataTest extends AbstractTimeSeriesTest { String modelId; String entityName; String detectorId; - ModelState<EntityModel> modelState; + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState; Clock clock; float priority; - EntityColdStarter entityColdStarter; - NodeStateManager stateManager; + ADEntityColdStart entityColdStarter; + ADNodeStateManager stateManager; SearchFeatureDao searchFeatureDao; Imputer imputer; CheckpointDao checkpoint; @@ -82,11 +78,11 @@ public class AbstractCosineDataTest extends AbstractTimeSeriesTest { Runnable releaseSemaphore; ActionListener listener; CountDownLatch inProgressLatch; - CheckpointWriteWorker checkpointWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; Entity entity; AnomalyDetector detector; long rcfSeed; - ModelManager modelManager; + ADModelManager
modelManager; ClientUtil clientUtil; ClusterService clusterService; ClusterSettings clusterSettings; @@ -97,7 +93,7 @@ public class AbstractCosineDataTest extends AbstractTimeSeriesTest { @Override public void setUp() throws Exception { super.setUp(); - numMinSamples = AnomalyDetectorSettings.NUM_MIN_SAMPLES; + numMinSamples = TimeSeriesSettings.NUM_MIN_SAMPLES; clock = mock(Clock.class); when(clock.instant()).thenReturn(Instant.now()); @@ -125,8 +121,8 @@ public void setUp() throws Exception { }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class)); nodestateSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); - nodestateSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE); - nodestateSetting.add(BACKOFF_MINUTES); + nodestateSetting.add(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE); + nodestateSetting.add(AD_BACKOFF_MINUTES); nodestateSetting.add(CHECKPOINT_SAVING_FREQ); clusterSettings = new ClusterSettings(Settings.EMPTY, nodestateSetting); @@ -140,60 +136,58 @@ public void setUp() throws Exception { clusterService = ClusterServiceUtils.createClusterService(threadPool, discoveryNode, clusterSettings); - stateManager = new NodeStateManager( + stateManager = new ADNodeStateManager( client, xContentRegistry(), settings, clientUtil, clock, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, clusterService ); imputer = new LinearUniformImputer(true); searchFeatureDao = mock(SearchFeatureDao.class); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); featureManager = new FeatureManager( searchFeatureDao, imputer, clock, - AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); rcfSeed = 2051L; - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, imputer, searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, rcfSeed, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS ); detectorId = "123"; @@ -211,20 +205,20 @@ public void setUp() throws Exception { }; listener = ActionListener.wrap(releaseSemaphore); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.TIME_DECAY, - 
AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, entityColdStarter, mock(FeatureManager.class), - mock(MemoryTracker.class), + mock(ADMemoryTracker.class), settings, clusterService ); diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java index 8c3e6c472..6e3cf993f 100644 --- a/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java +++ b/src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java @@ -96,14 +96,13 @@ import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.util.ClientUtil; import org.opensearch.client.Client; import org.opensearch.core.index.shard.ShardId; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import test.org.opensearch.ad.util.JsonDeserializer; import test.org.opensearch.ad.util.MLUtil; @@ -127,7 +126,7 @@ public class CheckpointDaoTests extends OpenSearchTestCase { private static final Logger logger = LogManager.getLogger(CheckpointDaoTests.class); - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; // dependencies @Mock(answer = Answers.RETURNS_DEEP_STUBS) @@ -194,7 +193,7 @@ public GenericObjectPool<LinkedBuffer> run() { return new GenericObjectPool<>(new BasePooledObjectFactory<LinkedBuffer>() { @Override public LinkedBuffer create() throws Exception { - return LinkedBuffer.allocate(AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES); + return LinkedBuffer.allocate(TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES); } @Override @@ -204,14 +203,14 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) { }); } })); - serializeRCFBufferPool.setMaxTotal(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); - serializeRCFBufferPool.setMaxIdle(AnomalyDetectorSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxTotal(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); + serializeRCFBufferPool.setMaxIdle(TimeSeriesSettings.MAX_TOTAL_RCF_SERIALIZATION_BUFFERS); serializeRCFBufferPool.setMinIdle(0); serializeRCFBufferPool.setBlockWhenExhausted(false); - serializeRCFBufferPool.setTimeBetweenEvictionRuns(AnomalyDetectorSettings.HOURLY_MAINTENANCE); + serializeRCFBufferPool.setTimeBetweenEvictionRuns(TimeSeriesSettings.HOURLY_MAINTENANCE); anomalyRate = 0.005; - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -224,7 +223,7 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) { indexUtil, maxCheckpointBytes, serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); @@ -492,14 +491,15 @@ public void test_deleteModelCheckpoint_callListener_whenCompleted() { @SuppressWarnings("unchecked") public void test_restore() throws
IOException { - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - EntityModel modelToSave = state.getModel(); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + createFromValueOnlySamples<ThresholdedRandomCutForest> modelToSave = state.getModel(); GetResponse getResponse = mock(GetResponse.class); when(getResponse.isExists()).thenReturn(true); Map<String, Object> source = new HashMap<>(); - source.put(CheckpointDao.DETECTOR_ID, state.getId()); - source.put(CheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); + source.put(ADCheckpointDao.DETECTOR_ID, state.getConfigId()); + source.put(ADCheckpointDao.FIELD_MODELV2, checkpointDao.toCheckpoint(modelToSave, modelId).get()); source.put(CommonName.TIMESTAMP, "2020-10-11T22:58:23.610392Z"); when(getResponse.getSource()).thenReturn(source); @@ -510,14 +510,14 @@ public void test_restore() throws IOException { return null; }).when(clientUtil).asyncRequest(any(GetRequest.class), any(BiConsumer.class), any(ActionListener.class)); - ActionListener<Optional<Entry<EntityModel, Instant>>> listener = mock(ActionListener.class); + ActionListener<Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>>> listener = mock(ActionListener.class); checkpointDao.deserializeModelCheckpoint(modelId, listener); - ArgumentCaptor<Optional<Entry<EntityModel, Instant>>> responseCaptor = ArgumentCaptor.forClass(Optional.class); + ArgumentCaptor<Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>>> responseCaptor = ArgumentCaptor.forClass(Optional.class); verify(listener).onResponse(responseCaptor.capture()); - Optional<Entry<EntityModel, Instant>> response = responseCaptor.getValue(); + Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> response = responseCaptor.getValue(); assertTrue(response.isPresent()); - Entry<EntityModel, Instant> entry = response.get(); + Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant> entry = response.get(); OffsetDateTime utcTime = entry.getValue().atOffset(ZoneOffset.UTC); assertEquals(2020, utcTime.getYear()); assertEquals(Month.OCTOBER, utcTime.getMonth()); @@ -526,9 +526,9 @@ public void test_restore() throws IOException { assertEquals(58, utcTime.getMinute()); assertEquals(23, utcTime.getSecond()); - EntityModel model = entry.getKey(); - Queue<double[]> queue = model.getSamples(); - Queue<double[]> samplesToSave = modelToSave.getSamples(); + createFromValueOnlySamples<ThresholdedRandomCutForest> model = entry.getKey(); + Queue<double[]> queue = model.getValueOnlySamples(); + Queue<double[]> samplesToSave = modelToSave.getValueOnlySamples(); assertEquals(samplesToSave.size(), queue.size()); assertTrue(Arrays.equals(samplesToSave.peek(), queue.peek())); logger.info(modelToSave.getTrcf()); @@ -675,7 +675,7 @@ public void test_batch_read() throws InterruptedException { } public void test_too_large_checkpoint() throws IOException { - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -688,17 +688,19 @@ public void test_too_large_checkpoint() throws IOException { indexUtil, 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); assertTrue(checkpointDao.toIndexSource(state).isEmpty()); } public void test_to_index_source() throws IOException { - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); Map<String, Object> source = checkpointDao.toIndexSource(state);
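/*
 * toIndexSource serializes the model state into the checkpoint document that gets
 * indexed. A brief sketch of poking at the result; the field constants are reused
 * from the assertions elsewhere in this file, though whether both keys appear on
 * the write path is an assumption:
 *
 *   assertTrue(source.containsKey(CommonName.TIMESTAMP));
 *   assertNotNull(source.get(ADCheckpointDao.FIELD_MODELV2));
 */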
assertTrue(!source.isEmpty()); @@ -712,7 +714,7 @@ public void test_to_index_source() throws IOException { public void testBorrowFromPoolFailure() throws Exception { GenericObjectPool<LinkedBuffer> mockSerializeRCFBufferPool = mock(GenericObjectPool.class); when(mockSerializeRCFBufferPool.borrowObject()).thenThrow(NoSuchElementException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -725,11 +727,12 @@ public void testBorrowFromPoolFailure() throws Exception { indexUtil, 1, // make the max checkpoint size 1 byte only mockSerializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); assertTrue(!checkpointDao.toCheckpoint(state.getModel(), modelId).get().isEmpty()); } @@ -737,7 +740,7 @@ public void testMapperFailure() throws IOException { ThresholdedRandomCutForestMapper mockMapper = mock(ThresholdedRandomCutForestMapper.class); when(mockMapper.toState(any())).thenThrow(RuntimeException.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -750,42 +753,45 @@ public void testMapperFailure() throws IOException { indexUtil, 1, // make the max checkpoint size 1 byte only serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); // make sure sample size is not 0 otherwise sample size won't be written to checkpoint - ModelState<EntityModel> state = MLUtil.randomModelState(new
RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(1).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertEquals(null, JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertEquals(null, JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); assertTrue(null != JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); // assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); } public void testEmptySample() throws IOException { - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); assertEquals(null, JsonDeserializer.getChildNode(json, CommonName.ENTITY_SAMPLE)); // assertTrue(null != JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_THRESHOLD)); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointErcfCheckoutFail() throws Exception { when(serializeRCFBufferPool.borrowObject()).thenThrow(RuntimeException.class); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } @SuppressWarnings("unchecked") private void setUpMockTrcf() { trcfMapper = mock(ThresholdedRandomCutForestMapper.class); trcfSchema = mock(Schema.class); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -798,7 +804,7 @@ private void setUpMockTrcf() { indexUtil, maxCheckpointBytes, serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); } @@ -807,10 +813,11 @@ public void testToCheckpointTrcfCheckoutBufferFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfFailNewBuffer() throws Exception { @@ -818,10 +825,11 @@ public void testToCheckpointTrcfFailNewBuffer() throws Exception { doReturn(null).when(serializeRCFBufferPool).borrowObject(); when(trcfMapper.toState(any())).thenThrow(RuntimeException.class); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception { @@ -829,41 +837,44 @@ public void testToCheckpointTrcfCheckoutBufferInvalidateFail() throws Exception when(trcfMapper.toState(any())).thenThrow(RuntimeException.class).thenReturn(null); doThrow(RuntimeException.class).when(serializeRCFBufferPool).invalidateObject(any()); - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).sampleSize(0).build()); String json = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); - assertNotNull(JsonDeserializer.getChildNode(json, CheckpointDao.ENTITY_TRCF)); + assertNotNull(JsonDeserializer.getChildNode(json, ADCheckpointDao.ENTITY_TRCF)); } public void testFromEntityModelCheckpointWithTrcf() throws Exception { - ModelState<EntityModel> state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); String model =
checkpointDao.toCheckpoint(state.getModel(), modelId).get(); Map entity = new HashMap<>(); entity.put(FIELD_MODELV2, model); entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); + Entry pair = result.get(); + createFromValueOnlySamples entityModel = pair.getKey(); assertTrue(entityModel.getTrcf().isPresent()); } public void testFromEntityModelCheckpointTrcfMapperFail() throws Exception { setUpMockTrcf(); when(trcfMapper.toModel(any())).thenThrow(RuntimeException.class); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); String model = checkpointDao.toCheckpoint(state.getModel(), modelId).get(); Map entity = new HashMap<>(); entity.put(FIELD_MODELV2, model); entity.put(CommonName.TIMESTAMP, Instant.now().toString()); - Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); + Optional> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId); assertTrue(result.isPresent()); - Entry pair = result.get(); - EntityModel entityModel = pair.getKey(); + Entry pair = result.get(); + createFromValueOnlySamples entityModel = pair.getKey(); assertFalse(entityModel.getTrcf().isPresent()); } @@ -892,14 +903,15 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); Instant now = modelPair.getRight(); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + Optional> result = checkpointDao + .fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); assertTrue(result.isPresent()); - Entry pair = result.get(); + Entry pair = result.get(); assertEquals(now, pair.getValue()); - EntityModel entityModel = pair.getKey(); + createFromValueOnlySamples entityModel = pair.getKey(); - Queue samples = entityModel.getSamples(); + Queue samples = entityModel.getValueOnlySamples(); assertEquals(6, samples.size()); double[] firstSample = samples.peek(); assertEquals(1, firstSample.length); @@ -924,7 +936,7 @@ public void testFromEntityModelCheckpointBWC() throws FileNotFoundException, IOE public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundException, IOException, URISyntaxException { Pair, Instant> modelPair = setUp1_0Model("checkpoint_2.json"); - checkpointDao = new CheckpointDao( + checkpointDao = new ADCheckpointDao( client, clientUtil, indexName, @@ -937,10 +949,11 @@ public void testFromEntityModelCheckpointModelTooLarge() throws FileNotFoundExce indexUtil, 100_000, // checkpoint_2.json is of 224603 bytes. serializeRCFBufferPool, - AnomalyDetectorSettings.SERIALIZATION_BUFFER_BYTES, + TimeSeriesSettings.SERIALIZATION_BUFFER_BYTES, anomalyRate ); - Optional> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); + Optional> result = checkpointDao + .fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId); // checkpoint is only configured to take in 1 MB checkpoint at most. But the checkpoint here is of 1408047 bytes. 
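[Review note] The pool-failure tests above and this size-guard test pin down one contract: toCheckpoint must survive a broken serialization-buffer pool, and fromEntityModelCheckpoint must drop checkpoints larger than maxCheckpointBytes instead of loading them. Below is a minimal sketch of the resource-handling shape only, assuming a commons-pool2 GenericObjectPool<LinkedBuffer> of protostuff buffers and a hypothetical serialize() helper; the real implementation also retries serialization with a fresh buffer after invalidating a pooled one, which the sketch omits for brevity. This is an illustration, not the plugin's actual code:

    private Optional<String> toCheckpointSketch(ThresholdedRandomCutForest trcf) {
        LinkedBuffer buffer = null;
        try {
            buffer = serializeRCFBufferPool.borrowObject();
        } catch (Exception e) {
            // pool is broken (e.g. NoSuchElementException): fall through and use a one-off buffer
        }
        try {
            LinkedBuffer effective = buffer != null ? buffer : LinkedBuffer.allocate(serializeRCFBufferSize);
            return Optional.of(serialize(trcf, effective)); // hypothetical helper
        } catch (RuntimeException mapperFailure) {
            if (buffer != null) {
                try {
                    serializeRCFBufferPool.invalidateObject(buffer); // may itself throw; swallow
                } catch (Exception ignored) {
                }
                buffer = null; // do not return a possibly corrupt buffer to the pool
            }
            return Optional.empty(); // checkpoint is then written without the ENTITY_TRCF field
        } finally {
            if (buffer != null) {
                serializeRCFBufferPool.returnObject(buffer);
            }
        }
    }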
         assertTrue(!result.isPresent());
     }

@@ -950,28 +963,31 @@ public void testFromEntityModelCheckpointEmptyModel() throws FileNotFoundExcepti
         Map<String, Object> entity = new HashMap<>();
         entity.put(CommonName.TIMESTAMP, Instant.now().toString());
-        Optional<Entry<EntityModel, Instant>> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId);
+        Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> result = checkpointDao.fromEntityModelCheckpoint(entity, this.modelId);
         assertTrue(!result.isPresent());
     }

     public void testFromEntityModelCheckpointEmptySamples() throws FileNotFoundException, IOException, URISyntaxException {
         Pair<Map<String, Object>, Instant> modelPair = setUp1_0Model("checkpoint_1.json");
-        Optional<Entry<EntityModel, Instant>> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
+        Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> result = checkpointDao
+            .fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
         assertTrue(result.isPresent());
-        Queue<double[]> samples = result.get().getKey().getSamples();
+        Queue<double[]> samples = result.get().getKey().getValueOnlySamples();
         assertEquals(0, samples.size());
     }

     public void testFromEntityModelCheckpointNoRCF() throws FileNotFoundException, IOException, URISyntaxException {
         Pair<Map<String, Object>, Instant> modelPair = setUp1_0Model("checkpoint_3.json");
-        Optional<Entry<EntityModel, Instant>> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
+        Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> result = checkpointDao
+            .fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
         assertTrue(result.isPresent());
         assertTrue(!result.get().getKey().getTrcf().isPresent());
     }

     public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundException, IOException, URISyntaxException {
         Pair<Map<String, Object>, Instant> modelPair = setUp1_0Model("checkpoint_4.json");
-        Optional<Entry<EntityModel, Instant>> result = checkpointDao.fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
+        Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> result = checkpointDao
+            .fromEntityModelCheckpoint(modelPair.getLeft(), this.modelId);
         assertTrue(result.isPresent());

         ThresholdedRandomCutForest trcf = result.get().getKey().getTrcf().get();
@@ -984,17 +1000,17 @@ public void testFromEntityModelCheckpointNoThreshold() throws FileNotFoundExcept
     }

     public void testFromEntityModelCheckpointWithEntity() throws Exception {
-        ModelState<EntityModel> state = MLUtil
+        ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> state = MLUtil
             .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).entityAttributes(true).build());
         Map<String, Object> content = checkpointDao.toIndexSource(state);
         // Opensearch will convert from java.time.ZonedDateTime to String. Here I am converting to simulate that
         content.put(CommonName.TIMESTAMP, "2021-09-23T05:00:37.93195Z");

-        Optional<Entry<EntityModel, Instant>> result = checkpointDao.fromEntityModelCheckpoint(content, this.modelId);
+        Optional<Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant>> result = checkpointDao.fromEntityModelCheckpoint(content, this.modelId);

         assertTrue(result.isPresent());
-        Entry<EntityModel, Instant> pair = result.get();
-        EntityModel entityModel = pair.getKey();
+        Entry<createFromValueOnlySamples<ThresholdedRandomCutForest>, Instant> pair = result.get();
+        createFromValueOnlySamples<ThresholdedRandomCutForest> entityModel = pair.getKey();
         assertTrue(entityModel.getEntity().isPresent());
         assertEquals(state.getModel().getEntity().get(), entityModel.getEntity().get());
     }
diff --git a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java
index c94c145cb..8571adb4a 100644
--- a/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java
+++ b/src/test/java/org/opensearch/ad/ml/CheckpointDeleteTests.java
@@ -29,7 +29,6 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.util.ClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.index.reindex.BulkByScrollResponse;
@@ -60,7 +59,7 @@ private enum DeleteExecutionMode {
         PARTIAL_FAILURE
     }

-    private CheckpointDao checkpointDao;
+    private ADCheckpointDao checkpointDao;
     private Client client;
     private ClientUtil clientUtil;
     private Gson gson;
@@ -82,7 +81,7 @@ private enum DeleteExecutionMode {
     @Before
     public void setUp() throws Exception {
         super.setUp();
-        super.setUpLog4jForJUnit(CheckpointDao.class);
+        super.setUpLog4jForJUnit(ADCheckpointDao.class);

         client = mock(Client.class);
         clientUtil = mock(ClientUtil.class);
@@ -97,7 +96,7 @@ public void setUp() throws Exception {
         objectPool = mock(GenericObjectPool.class);
         int deserializeRCFBufferSize = 512;
         anomalyRate = 0.005;
-        checkpointDao = new CheckpointDao(
+        checkpointDao = new ADCheckpointDao(
             client,
             clientUtil,
             ADCommonName.CHECKPOINT_INDEX_NAME,
@@ -157,7 +156,7 @@ public void delete_by_detector_id_template(DeleteExecutionMode mode) {
             return null;
         }).when(client).execute(eq(DeleteByQueryAction.INSTANCE), any(), any());

-        checkpointDao.deleteModelCheckpointByDetectorId(detectorId);
+        checkpointDao.deleteModelCheckpointByConfigId(detectorId);
     }

     public void testDeleteSingleNormal() throws Exception {
@@ -172,7 +171,7 @@ public void testDeleteSingleIndexNotFound() throws Exception {
     public void testDeleteSingleResultFailure() throws Exception {
         delete_by_detector_id_template(DeleteExecutionMode.FAILURE);
-        assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_LOG_MSG));
+        assertTrue(testAppender.containsMessage(CheckpointDao.NOT_ABLE_TO_DELETE_CHECKPOINT_MSG));
     }

     public void testDeleteSingleResultPartialFailure() throws Exception {
diff --git a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
index 34265b0e6..17a8268e5 100644
--- a/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
+++ b/src/test/java/org/opensearch/ad/ml/EntityColdStarterTests.java
@@ -43,9 +43,6 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.get.GetRequest;
 import org.opensearch.action.get.GetResponse;
-import org.opensearch.ad.MemoryTracker;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.ml.ModelManager.ModelType;
 import org.opensearch.ad.settings.ADEnabledSetting;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.service.ClusterService;
@@ -57,6 +54,8 @@
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.ml.ModelManager;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
 import org.opensearch.timeseries.settings.TimeSeriesSettings;
@@ -103,11 +102,11 @@ public void tearDown() throws Exception {
     // train using samples directly
     public void testTrainUsingSamples() throws InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(numMinSamples);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
-        assertTrue(model.getTrcf().isPresent());
-        ThresholdedRandomCutForest ercf = model.getTrcf().get();
+        assertTrue(model.getModel().isPresent());
+        ThresholdedRandomCutForest ercf = model.getModel().get();
         assertEquals(numMinSamples, ercf.getForest().getTotalUpdates());

         checkSemaphoreRelease();
@@ -115,14 +114,14 @@
     public void testColdStart() throws InterruptedException, IOException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
@@ -142,24 +141,24 @@ public void testColdStart() throws InterruptedException, IOException {
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         checkSemaphoreRelease();
-        assertTrue(model.getTrcf().isPresent());
-        ThresholdedRandomCutForest ercf = model.getTrcf().get();
+        assertTrue(model.getModel().isPresent());
+        ThresholdedRandomCutForest ercf = model.getModel().get();
         // 1 round: stride * (samples - 1) + 1 = 60 * 2 + 1 = 121
         // plus 1 existing sample
         assertEquals(121, ercf.getForest().getTotalUpdates());
-        assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty());
+        assertTrue("size: " + model.getValueOnlySamples().size(), model.getValueOnlySamples().isEmpty());
         checkSemaphoreRelease();

         released.set(false);
         // too frequent cold start of the same detector will fail
         samples = MLUtil.createQueueSamples(1);
-        model = new EntityModel(entity, samples, null);
+        model = new createFromValueOnlySamples<>(entity, samples, null);
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

-        assertFalse(model.getTrcf().isPresent());
+        assertFalse(model.getModel().isPresent());
         // the samples is not touched since cold start does not happen
-        assertEquals("size: " + model.getSamples().size(), 1, model.getSamples().size());
+        assertEquals("size: " + model.getValueOnlySamples().size(), 1, model.getValueOnlySamples().size());
         checkSemaphoreRelease();

         List<double[]> expectedColdStartData = new ArrayList<>();
@@ -179,20 +178,20 @@ public void testColdStart() throws InterruptedException, IOException {
     // min max: miss one
     public void testMissMin() throws IOException, InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.empty());
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         verify(searchFeatureDao, never()).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

-        assertTrue(!model.getTrcf().isPresent());
+        assertTrue(!model.getModel().isPresent());
         checkSemaphoreRelease();
     }

@@ -201,7 +200,10 @@ public void testMissMin() throws IOException, InterruptedException {
      * @param modelState an initialized model state
      * @param coldStartData cold start data that initialized the modelState
      */
-    private void diffTesting(ModelState<EntityModel> modelState, List<double[]> coldStartData) {
+    private void diffTesting(
+        ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState,
+        List<double[]> coldStartData
+    ) {
         int inputDimension = detector.getEnabledFeatureIds().size();

         ThresholdedRandomCutForest refTRcf = ThresholdedRandomCutForest
@@ -210,16 +212,16 @@
             .dimensions(inputDimension * detector.getShingleSize())
             .precision(Precision.FLOAT_32)
             .randomSeed(rcfSeed)
-            .numberOfTrees(AnomalyDetectorSettings.NUM_TREES)
+            .numberOfTrees(TimeSeriesSettings.NUM_TREES)
             .shingleSize(detector.getShingleSize())
             .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO)
-            .timeDecay(AnomalyDetectorSettings.TIME_DECAY)
+            .timeDecay(TimeSeriesSettings.TIME_DECAY)
             .outputAfter(numMinSamples)
             .initialAcceptFraction(0.125d)
             .parallelExecutionEnabled(false)
-            .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE)
+            .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE)
             .internalShinglingEnabled(true)
-            .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE)
+            .anomalyRate(1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE)
             .build();

         for (int i = 0; i < coldStartData.size(); i++) {
@@ -237,8 +239,7 @@
         for (int i = 0; i < 100; i++) {
             double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray();
             AnomalyDescriptor descriptor = refTRcf.process(point, 0);
-            ThresholdingResult result = modelManager
-                .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize());
+            ThresholdingResult result = modelManager.getResult(point, modelState, modelId, entity, detector.getShingleSize());
             assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10);
             assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10);
         }
@@ -264,14 +265,14 @@ private List<double[]> convertToFeatures(double[][] interval, int numValsToKeep)
     public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
         double[] savedSample = samples.peek();
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
         double[] sample1 = new double[] { 57.0 };
@@ -291,11 +292,11 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         checkSemaphoreRelease();
-        assertTrue(model.getTrcf().isPresent());
+        assertTrue(model.getModel().isPresent());

         // 1 round: stride * (samples - 1) + 1 = 60 * 4 + 1 = 241
         // if 241 < shingle size + numMinSamples, then another round is performed
-        assertEquals(241, modelState.getModel().getTrcf().get().getForest().getTotalUpdates());
+        assertEquals(241, modelState.getModel().getModel().get().getForest().getTotalUpdates());
         checkSemaphoreRelease();

         List<double[]> expectedColdStartData = new ArrayList<>();
@@ -309,7 +310,7 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc
         expectedColdStartData.addAll(convertToFeatures(interval2, 60));
         double[][] interval3 = imputer.impute(new double[][] { new double[] { sample3[0], sample5[0] } }, 121);
         expectedColdStartData.addAll(convertToFeatures(interval3, 121));
-        assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty());
+        assertTrue("size: " + model.getValueOnlySamples().size(), model.getValueOnlySamples().isEmpty());
         assertEquals(241, expectedColdStartData.size());
         diffTesting(modelState, expectedColdStartData);
     }
@@ -317,15 +318,14 @@ public void testTwoSegmentsWithSingleSample() throws InterruptedException, IOExc
     // two segments of samples, one segment has 3 samples, while another one 2 samples
     public void testTwoSegments() throws InterruptedException, IOException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        double[] savedSample = samples.peek();
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
         double[] sample1 = new double[] { 57.0 };
@@ -348,8 +348,8 @@ public void testTwoSegments() throws InterruptedException, IOException {
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         checkSemaphoreRelease();
-        assertTrue(model.getTrcf().isPresent());
-        ThresholdedRandomCutForest ercf = model.getTrcf().get();
+        assertTrue(model.getModel().isPresent());
+        ThresholdedRandomCutForest ercf = model.getModel().get();
         // 1 rounds: stride * (samples - 1) + 1 = 60 * 5 + 1 = 301
         assertEquals(301, ercf.getForest().getTotalUpdates());
         checkSemaphoreRelease();
@@ -368,40 +368,40 @@ public void testTwoSegments() throws InterruptedException, IOException {
         double[][] interval4 = imputer.impute(new double[][] { new double[] { sample5[0], sample6[0] } }, 61);
         expectedColdStartData.addAll(convertToFeatures(interval4, 61));
         assertEquals(301, expectedColdStartData.size());
-        assertTrue("size: " + model.getSamples().size(), model.getSamples().isEmpty());
+        assertTrue("size: " + model.getValueOnlySamples().size(), model.getValueOnlySamples().isEmpty());
         diffTesting(modelState, expectedColdStartData);
     }

     public void testThrottledColdStart() throws InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onFailure(new OpenSearchRejectedExecutionException(""));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         entityColdStarter.trainModel(entity, "456", modelState, listener);

         // only the first one makes the call
-        verify(searchFeatureDao, times(1)).getEntityMinDataTime(any(), any(), any());
+        verify(searchFeatureDao, times(1)).getMinDataTime(any(), any(), any());
         checkSemaphoreRelease();
     }

     public void testColdStartException() throws InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onFailure(new TimeSeriesException(detectorId, ""));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
@@ -412,8 +412,8 @@ public void testColdStartException() throws InterruptedException {
     @SuppressWarnings("unchecked")
     public void testNotEnoughSamples() throws InterruptedException, IOException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);
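[Review note] Every cold-start test in this file stubs the asynchronous SearchFeatureDao the same way; this PR only renames the stubbed method (getEntityMinDataTime -> getMinDataTime). The recurring shape, lifted from the tests above, for reference:

    doAnswer(invocation -> {
        // the ActionListener is the 3rd argument of getMinDataTime in these tests
        ActionListener<Optional<Long>> listener = invocation.getArgument(2);
        listener.onResponse(Optional.of(1602269260000L)); // earliest data point, epoch millis
        return null;
    }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

Completing the listener synchronously inside doAnswer keeps these tests deterministic; the checkSemaphoreRelease() calls then only wait on the training listener (ActionListener.wrap(releaseSemaphore), see createStateForCacheRelease below).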
         detector = TestHelpers.AnomalyDetectorBuilder
             .newInstance()
@@ -432,7 +432,7 @@ public void testNotEnoughSamples() throws InterruptedException, IOException {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
         coldStartSamples.add(Optional.of(new double[] { 57.0 }));
@@ -446,10 +446,10 @@ public void testNotEnoughSamples() throws InterruptedException, IOException {
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         checkSemaphoreRelease();
-        assertTrue(!model.getTrcf().isPresent());
+        assertTrue(!model.getModel().isPresent());
         // 1st round we add 57 and 1.
         // 2nd round we add 57 and 1.
-        Queue<double[]> currentSamples = model.getSamples();
+        Queue<double[]> currentSamples = model.getValueOnlySamples();
         assertEquals("real sample size is " + currentSamples.size(), 4, currentSamples.size());
         int j = 0;
         while (!currentSamples.isEmpty()) {
@@ -467,8 +467,8 @@ public void testNotEnoughSamples() throws InterruptedException, IOException {
     @SuppressWarnings("unchecked")
     public void testEmptyDataRange() throws InterruptedException {
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         // the min-max range 894056973000L~894057860000L is too small and thus no data range can be found
         when(clock.millis()).thenReturn(894057860000L);
@@ -485,14 +485,14 @@ public void testEmptyDataRange() throws InterruptedException {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(894056973000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         entityColdStarter.trainModel(entity, detectorId, modelState, listener);

         checkSemaphoreRelease();
-        assertTrue(!model.getTrcf().isPresent());
+        assertTrue(!model.getModel().isPresent());
         // the min-max range is too small and thus no data range can be found
-        assertEquals("real sample size is " + model.getSamples().size(), 1, model.getSamples().size());
+        assertEquals("real sample size is " + model.getValueOnlySamples().size(), 1, model.getValueOnlySamples().size());
     }

     public void testTrainModelFromExistingSamplesEnoughSamples() {
@@ -505,22 +505,22 @@ public void testTrainModelFromExistingSamplesEnoughSamples() {
             .dimensions(dimensions)
             .precision(Precision.FLOAT_32)
             .randomSeed(rcfSeed)
-            .numberOfTrees(AnomalyDetectorSettings.NUM_TREES)
+            .numberOfTrees(TimeSeriesSettings.NUM_TREES)
             .shingleSize(detector.getShingleSize())
             .boundingBoxCacheFraction(TimeSeriesSettings.REAL_TIME_BOUNDING_BOX_CACHE_RATIO)
-            .timeDecay(AnomalyDetectorSettings.TIME_DECAY)
+            .timeDecay(TimeSeriesSettings.TIME_DECAY)
             .outputAfter(numMinSamples)
             .initialAcceptFraction(0.125d)
             .parallelExecutionEnabled(false)
-            .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE)
+            .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE)
             .internalShinglingEnabled(true)
-            .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE);
+            .anomalyRate(1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE);
         Tuple<Queue<double[]>, ThresholdedRandomCutForest> models = MLUtil.prepareModel(inputDimension, rcfConfig);
         Queue<double[]> samples = models.v1();
         ThresholdedRandomCutForest rcf = models.v2();
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);

         Random r = new Random();
@@ -528,8 +528,7 @@ public void testTrainModelFromExistingSamplesEnoughSamples() {
         for (int i = 0; i < 100; i++) {
             double[] point = r.ints(inputDimension, 0, 50).asDoubleStream().toArray();
             AnomalyDescriptor descriptor = rcf.process(point, 0);
-            ThresholdingResult result = modelManager
-                .getAnomalyResultForEntity(point, modelState, modelId, entity, detector.getShingleSize());
+            ThresholdingResult result = modelManager.getResult(point, modelState, modelId, entity, detector.getShingleSize());
             assertEquals(descriptor.getRCFScore(), result.getRcfScore(), 1e-10);
             assertEquals(descriptor.getAnomalyGrade(), result.getGrade(), 1e-10);
         }
@@ -537,16 +536,16 @@ public void testTrainModelFromExistingSamplesEnoughSamples() {
     public void testTrainModelFromExistingSamplesNotEnoughSamples() {
         Queue<double[]> samples = new ArrayDeque<>();
-        EntityModel model = new EntityModel(entity, samples, null);
-        modelState = new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        modelState = new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);
         entityColdStarter.trainModelFromExistingSamples(modelState, detector.getShingleSize());
-        assertTrue(!modelState.getModel().getTrcf().isPresent());
+        assertTrue(!modelState.getModel().getModel().isPresent());
     }

     @SuppressWarnings("unchecked")
     private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold, float recallThreshold) throws Exception {
         int baseDimension = 2;
-        int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE;
+        int dataSize = 20 * TimeSeriesSettings.NUM_SAMPLES_PER_TREE;
         int trainTestSplit = 300;
         // detector interval
         int interval = detectorIntervalMins;
@@ -597,7 +596,7 @@ private void accuracyTemplate(int detectorIntervalMins, float precisionThreshold
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(timestamps[0]));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         doAnswer(invocation -> {
             List<Entry<Long, Long>> ranges = invocation.getArgument(1);
@@ -621,8 +620,12 @@ public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
             return null;
         }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any());

-        EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null);
-        modelState = new ModelState<>(model, modelId, detector.getId(), ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(
+            entity,
+            new ArrayDeque<>(),
+            null
+        );
+        modelState = new ADModelState<>(model, modelId, detector.getId(), ModelManager.ModelType.ENTITY.getName(), clock, priority);

         released = new AtomicBoolean();
@@ -635,7 +638,7 @@ public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
         entityColdStarter.trainModel(entity, detector.getId(), modelState, listener);
         checkSemaphoreRelease();
-        assertTrue(model.getTrcf().isPresent());
+        assertTrue(model.getModel().isPresent());

         int tp = 0;
         int fp = 0;
@@ -643,8 +646,7 @@ public int compare(Entry<Long, Long> p1, Entry<Long, Long> p2) {
         long[] changeTimestamps = dataWithKeys.changeTimeStampsMs;

         for (int j = trainTestSplit; j < data.length; j++) {
-            ThresholdingResult result = modelManager
-                .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize());
+            ThresholdingResult result = modelManager.getResult(data[j], modelState, modelId, entity, detector.getShingleSize());
             if (result.getGrade() > 0) {
                 if (changeTimestamps[j] == 0) {
                     fp++;
@@ -692,41 +694,41 @@ public void testAccuracyThirteenMinuteInterval() throws Exception {
     public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception {
         ADEnabledSetting.getInstance().setSettingValue(ADEnabledSetting.INTERPOLATION_IN_HCAD_COLD_START_ENABLED, false);
         // for one minute interval, we need to disable interpolation to achieve good results
-        entityColdStarter = new EntityColdStarter(
+        entityColdStarter = new ADEntityColdStart(
             clock,
             threadPool,
             stateManager,
-            AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
-            AnomalyDetectorSettings.NUM_TREES,
-            AnomalyDetectorSettings.TIME_DECAY,
+            TimeSeriesSettings.NUM_SAMPLES_PER_TREE,
+            TimeSeriesSettings.NUM_TREES,
+            TimeSeriesSettings.TIME_DECAY,
             numMinSamples,
             AnomalyDetectorSettings.MAX_SAMPLE_STRIDE,
             AnomalyDetectorSettings.MAX_TRAIN_SAMPLE,
             imputer,
             searchFeatureDao,
-            AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
+            TimeSeriesSettings.THRESHOLD_MIN_PVALUE,
             featureManager,
             settings,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             checkpointWriteQueue,
             rcfSeed,
-            AnomalyDetectorSettings.MAX_COLD_START_ROUNDS
+            TimeSeriesSettings.MAX_COLD_START_ROUNDS
         );

-        modelManager = new ModelManager(
-            mock(CheckpointDao.class),
+        modelManager = new ADModelManager(
+            mock(ADCheckpointDao.class),
             mock(Clock.class),
-            AnomalyDetectorSettings.NUM_TREES,
-            AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE,
-            AnomalyDetectorSettings.TIME_DECAY,
-            AnomalyDetectorSettings.NUM_MIN_SAMPLES,
-            AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE,
+            TimeSeriesSettings.NUM_TREES,
+            TimeSeriesSettings.NUM_SAMPLES_PER_TREE,
+            TimeSeriesSettings.TIME_DECAY,
+            TimeSeriesSettings.NUM_MIN_SAMPLES,
+            TimeSeriesSettings.THRESHOLD_MIN_PVALUE,
             AnomalyDetectorSettings.MIN_PREVIEW_SIZE,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ,
             entityColdStarter,
             mock(FeatureManager.class),
-            mock(MemoryTracker.class),
+            mock(ADMemoryTracker.class),
             settings,
             clusterService
         );
@@ -734,7 +736,7 @@ public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception {
         accuracyTemplate(1, 0.6f, 0.6f);
     }

-    private ModelState<EntityModel> createStateForCacheRelease() {
+    private ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> createStateForCacheRelease() {
         inProgressLatch = new CountDownLatch(1);
         releaseSemaphore = () -> {
             released.set(true);
@@ -742,17 +744,17 @@ private ModelState<EntityModel> createStateForCacheRelease() {
         };
         listener = ActionListener.wrap(releaseSemaphore);
         Queue<double[]> samples = MLUtil.createQueueSamples(1);
-        EntityModel model = new EntityModel(entity, samples, null);
-        return new ModelState<>(model, modelId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = new createFromValueOnlySamples<>(entity, samples, null);
+        return new ADModelState<>(model, modelId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);
     }

     public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedException {
-        ModelState<EntityModel> modelState = createStateForCacheRelease();
+        ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState = createStateForCacheRelease();
         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
@@ -771,13 +773,13 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
         checkSemaphoreRelease();
-        assertTrue(modelState.getModel().getTrcf().isPresent());
+        assertTrue(modelState.getModel().getModel().isPresent());

         modelState = createStateForCacheRelease();
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
         checkSemaphoreRelease();
         // model is not trained as the door keeper remembers it and won't retry training
-        assertTrue(!modelState.getModel().getTrcf().isPresent());
+        assertTrue(!modelState.getModel().getModel().isPresent());

         // make sure when the next maintenance coming, current door keeper gets reset
         // note our detector interval is 1 minute and the door keeper will expire in 60 intervals, which are 60 minutes
@@ -788,16 +790,16 @@ public void testCacheReleaseAfterMaintenance() throws IOException, InterruptedEx
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
         checkSemaphoreRelease();
         // model is trained as the door keeper gets reset
-        assertTrue(modelState.getModel().getTrcf().isPresent());
+        assertTrue(modelState.getModel().getModel().isPresent());
     }

     public void testCacheReleaseAfterClear() throws IOException, InterruptedException {
-        ModelState<EntityModel> modelState = createStateForCacheRelease();
+        ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> modelState = createStateForCacheRelease();
         doAnswer(invocation -> {
             ActionListener<Optional<Long>> listener = invocation.getArgument(2);
             listener.onResponse(Optional.of(1602269260000L));
             return null;
-        }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any());
+        }).when(searchFeatureDao).getMinDataTime(any(), any(), any());

         List<Optional<double[]>> coldStartSamples = new ArrayList<>();
@@ -816,7 +818,7 @@ public void testCacheReleaseAfterClear() throws IOException, InterruptedExceptio
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
         checkSemaphoreRelease();
-        assertTrue(modelState.getModel().getTrcf().isPresent());
+        assertTrue(modelState.getModel().getModel().isPresent());

         entityColdStarter.clear(detectorId);
@@ -824,6 +826,6 @@ public void testCacheReleaseAfterClear() throws IOException, InterruptedExceptio
         entityColdStarter.trainModel(entity, detectorId, modelState, listener);
         checkSemaphoreRelease();
         // model is trained as the door keeper is regenerated after clearance
-        assertTrue(modelState.getModel().getTrcf().isPresent());
+        assertTrue(modelState.getModel().getModel().isPresent());
     }
 }
diff --git a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java
index 1f4afe829..d471494af 100644
--- a/src/test/java/org/opensearch/ad/ml/EntityModelTests.java
+++ b/src/test/java/org/opensearch/ad/ml/EntityModelTests.java
@@ -16,6 +16,7 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.ml.createFromValueOnlySamples;
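[Review note] For readers tracking the rename across this file and the ones above: the old EntityModel API maps onto the new generic container one-for-one (the type parameter is ThresholdedRandomCutForest wherever these tests instantiate it). Old vs. new, side by side:

    EntityModel model = new EntityModel(entity, samples, null);       // before
    createFromValueOnlySamples<ThresholdedRandomCutForest> model =
        new createFromValueOnlySamples<>(entity, samples, null);      // after

    model.getTrcf()      ->  model.getModel()
    model.setTrcf(trcf)  ->  model.setModel(trcf)
    model.getSamples()   ->  model.getValueOnlySamples()
    model.addSample(s)   ->  model.addValueOnlySample(s)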
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest; @@ -29,45 +30,45 @@ public void setup() { } public void testNullInternalSampleQueue() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] { 0.8 }); - assertEquals(1, model.getSamples().size()); + createFromValueOnlySamples model = new createFromValueOnlySamples<>(null, null, null); + model.addValueOnlySample(new double[] { 0.8 }); + assertEquals(1, model.getValueOnlySamples().size()); } public void testNullInputSample() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(null); - assertEquals(0, model.getSamples().size()); + createFromValueOnlySamples model = new createFromValueOnlySamples<>(null, null, null); + model.addValueOnlySample(null); + assertEquals(0, model.getValueOnlySamples().size()); } public void testEmptyInputSample() { - EntityModel model = new EntityModel(null, null, null); - model.addSample(new double[] {}); - assertEquals(0, model.getSamples().size()); + createFromValueOnlySamples model = new createFromValueOnlySamples<>(null, null, null); + model.addValueOnlySample(new double[] {}); + assertEquals(0, model.getValueOnlySamples().size()); } @Test public void trcf_constructor() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); - assertEquals(trcf, em.getTrcf().get()); + createFromValueOnlySamples em = new createFromValueOnlySamples<>(null, new ArrayDeque<>(), trcf); + assertEquals(trcf, em.getModel().get()); } @Test public void clear() { - EntityModel em = new EntityModel(null, new ArrayDeque<>(), trcf); + createFromValueOnlySamples em = new createFromValueOnlySamples<>(null, new ArrayDeque<>(), trcf); em.clear(); - assertTrue(em.getSamples().isEmpty()); - assertFalse(em.getTrcf().isPresent()); + assertTrue(em.getValueOnlySamples().isEmpty()); + assertFalse(em.getModel().isPresent()); } @Test public void setTrcf() { - EntityModel em = new EntityModel(null, null, null); - assertFalse(em.getTrcf().isPresent()); + createFromValueOnlySamples em = new createFromValueOnlySamples<>(null, null, null); + assertFalse(em.getModel().isPresent()); - em.setTrcf(this.trcf); - assertTrue(em.getTrcf().isPresent()); + em.setModel(this.trcf); + assertTrue(em.getModel().isPresent()); } } diff --git a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java index 6fd32c2c9..8c74901a2 100644 --- a/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java +++ b/src/test/java/org/opensearch/ad/ml/HCADModelPerfTests.java @@ -34,10 +34,6 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -76,7 +72,7 @@ private void averageAccuracyTemplate( int baseDimension, boolean anomalyIndependent ) throws Exception { - int dataSize = 20 * AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE; + int dataSize = 20 * TimeSeriesSettings.NUM_SAMPLES_PER_TREE; int trainTestSplit = 300; // detector interval int interval = detectorIntervalMins; @@ -116,54 +112,52 @@ private void averageAccuracyTemplate( searchFeatureDao, imputer, clock, - 
AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, - AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.TRAIN_SAMPLE_TIME_RANGE_IN_HOURS, AnomalyDetectorSettings.MIN_TRAIN_SAMPLES, AnomalyDetectorSettings.MAX_SHINGLE_PROPORTION_MISSING, AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, imputer, searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, seed, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS ); - modelManager = new ModelManager( - mock(CheckpointDao.class), + modelManager = new ADModelManager( + mock(ADCheckpointDao.class), mock(Clock.class), - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.TIME_DECAY, - AnomalyDetectorSettings.NUM_MIN_SAMPLES, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.TIME_DECAY, + TimeSeriesSettings.NUM_MIN_SAMPLES, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, AnomalyDetectorSettings.MIN_PREVIEW_SIZE, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, entityColdStarter, mock(FeatureManager.class), - mock(MemoryTracker.class), + mock(ADMemoryTracker.class), settings, clusterService ); @@ -191,7 +185,7 @@ private void averageAccuracyTemplate( ActionListener> listener = invocation.getArgument(2); listener.onResponse(Optional.of(timestamps[0])); return null; - }).when(searchFeatureDao).getEntityMinDataTime(any(), any(), any()); + }).when(searchFeatureDao).getMinDataTime(any(), any(), any()); doAnswer(invocation -> { List> ranges = invocation.getArgument(1); @@ -216,12 +210,12 @@ public int compare(Entry p1, Entry p2) { }).when(searchFeatureDao).getColdStartSamplesForPeriods(any(), any(), any(), anyBoolean(), any()); entity = Entity.createSingleAttributeEntity("field", entityName + z); - EntityModel model = new EntityModel(entity, new ArrayDeque<>(), null); - ModelState modelState = new ModelState<>( + createFromValueOnlySamples model = new createFromValueOnlySamples(entity, new ArrayDeque<>(), null); + ADModelState modelState = new ADModelState<>( model, entity.getModelId(detectorId).get(), detector.getId(), - ModelType.ENTITY.getName(), + ModelManager.ModelType.ENTITY.getName(), clock, priority ); @@ -237,7 +231,7 @@ public int compare(Entry p1, Entry p2) { entityColdStarter.trainModel(entity, detector.getId(), modelState, listener); checkSemaphoreRelease(); - assertTrue(model.getTrcf().isPresent()); + assertTrue(model.getModel().isPresent()); int tp = 0; int fp = 0; @@ -245,8 +239,7 @@ public int 
compare(Entry p1, Entry p2) { long[] changeTimestamps = dataWithKeys.changeTimeStampsMs; for (int j = trainTestSplit; j < data.length; j++) { - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(data[j], modelState, modelId, entity, detector.getShingleSize()); + ThresholdingResult result = modelManager.getResult(data[j], modelState, modelId, entity, detector.getShingleSize()); if (result.getGrade() > 0) { if (changeTimestamps[j] == 0) { fp++; diff --git a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java index 7d981a297..023596351 100644 --- a/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/org/opensearch/ad/ml/ModelManagerTests.java @@ -55,15 +55,8 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.ActionListener; -import org.opensearch.ad.MemoryTracker; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.feature.SearchFeatureDao; -import org.opensearch.ad.ml.ModelManager.ModelType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -77,6 +70,7 @@ import org.opensearch.timeseries.common.exception.ResourceNotFoundException; import org.opensearch.timeseries.dataprocessor.LinearUniformImputer; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.powermock.modules.junit4.PowerMockRunner; import org.powermock.modules.junit4.PowerMockRunnerDelegate; @@ -94,7 +88,7 @@ @SuppressWarnings("unchecked") public class ModelManagerTests { - private ModelManager modelManager; + private ADModelManager modelManager; @Mock private AnomalyDetector anomalyDetector; @@ -106,7 +100,7 @@ public class ModelManagerTests { private JvmService jvmService; @Mock - private CheckpointDao checkpointDao; + private ADCheckpointDao checkpointDao; @Mock private Clock clock; @@ -115,16 +109,16 @@ public class ModelManagerTests { private FeatureManager featureManager; @Mock - private EntityColdStarter entityColdStarter; + private ADEntityColdStart entityColdStarter; @Mock private EntityCache cache; @Mock - private ModelState modelState; + private ADModelState modelState; @Mock - private EntityModel entityModel; + private createFromValueOnlySamples entityModel; @Mock private ThresholdedRandomCutForest trcf; @@ -164,11 +158,11 @@ public class ModelManagerTests { @Mock private ActionListener thresholdResultListener; - private MemoryTracker memoryTracker; + private ADMemoryTracker memoryTracker; private Instant now; @Mock - private ADCircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService adCircuitBreakerService; private String modelId = "modelId"; @@ -221,7 +215,7 @@ public void setup() { now = Instant.now(); when(clock.instant()).thenReturn(now); - memoryTracker = mock(MemoryTracker.class); + memoryTracker = mock(ADMemoryTracker.class); when(memoryTracker.isHostingAllowed(anyString(), any())).thenReturn(true); settings = Settings @@ -231,7 +225,7 @@ public void setup() { 
.build(); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -273,7 +267,7 @@ private Object[] getDetectorIdForModelIdData() { @Test @Parameters(method = "getDetectorIdForModelIdData") public void getDetectorIdForModelId_returnExpectedId(String modelId, String expectedDetectorId) { - assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getDetectorIdForModelId(modelId)); + assertEquals(expectedDetectorId, SingleStreamModelIdMapper.getConfigIdForModelId(modelId)); } private Object[] getDetectorIdForModelIdIllegalArgument() { @@ -283,7 +277,7 @@ private Object[] getDetectorIdForModelIdIllegalArgument() { @Test(expected = IllegalArgumentException.class) @Parameters(method = "getDetectorIdForModelIdIllegalArgument") public void getDetectorIdForModelId_throwIllegalArgument_forInvalidId(String modelId) { - SingleStreamModelIdMapper.getDetectorIdForModelId(modelId); + SingleStreamModelIdMapper.getConfigIdForModelId(modelId); } private Map createDataNodes(int numDataNodes) { @@ -413,7 +407,7 @@ public void getRcfResult_throwToListener_whenHeapLimitExceed() { when(jvmService.info().getMem().getHeapMax().getBytes()).thenReturn(1_000L); - MemoryTracker memoryTracker = new MemoryTracker( + ADMemoryTracker memoryTracker = new ADMemoryTracker( jvmService, modelMaxSizePercentage, modelDesiredSizePercentage, @@ -425,7 +419,7 @@ public void getRcfResult_throwToListener_whenHeapLimitExceed() { // use new memoryTracker modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -880,7 +874,8 @@ public void getPreviewResults_throwIllegalArgument_forInvalidInput() { @Test public void processEmptyCheckpoint() { - ModelState modelState = modelManager.processEntityCheckpoint(Optional.empty(), null, "", "", shingleSize); + ADModelState modelState = modelManager + .processEntityCheckpoint(Optional.empty(), null, "", "", shingleSize); assertEquals(Instant.MIN, modelState.getLastCheckpointTime()); } @@ -888,9 +883,9 @@ public void processEmptyCheckpoint() { public void processNonEmptyCheckpoint() { String modelId = "abc"; String detectorId = "123"; - EntityModel model = MLUtil.createNonEmptyModel(modelId); + createFromValueOnlySamples model = MLUtil.createNonEmptyModel(modelId); Instant checkpointTime = Instant.ofEpochMilli(1000); - ModelState modelState = modelManager + ADModelState modelState = modelManager .processEntityCheckpoint( Optional.of(new SimpleImmutableEntry<>(model, checkpointTime)), null, @@ -899,13 +894,13 @@ public void processNonEmptyCheckpoint() { shingleSize ); assertEquals(checkpointTime, modelState.getLastCheckpointTime()); - assertEquals(model.getSamples().size(), modelState.getModel().getSamples().size()); + assertEquals(model.getValueOnlySamples().size(), modelState.getModel().getValueOnlySamples().size()); assertEquals(now, modelState.getLastUsedTime()); } @Test public void getNullState() { - assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getAnomalyResultForEntity(new double[] {}, null, "", null, shingleSize)); + assertEquals(new ThresholdingResult(0, 0, 0), modelManager.getResult(new double[] {}, null, "", null, shingleSize)); } @Test @@ -914,7 +909,7 @@ public void getEmptyStateFullSamples() { LinearUniformImputer interpolator = new LinearUniformImputer(true); - NodeStateManager stateManager = mock(NodeStateManager.class); + ADNodeStateManager stateManager = mock(ADNodeStateManager.class); featureManager = new FeatureManager( searchFeatureDao, interpolator, @@ -927,35 +922,35 @@ public void 
getEmptyStateFullSamples() { AnomalyDetectorSettings.MAX_IMPUTATION_NEIGHBOR_DISTANCE, AnomalyDetectorSettings.PREVIEW_SAMPLE_RATE, AnomalyDetectorSettings.MAX_PREVIEW_SAMPLES, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, threadPool, TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME ); - CheckpointWriteWorker checkpointWriteQueue = mock(CheckpointWriteWorker.class); + ADCheckpointWriteWorker checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); - entityColdStarter = new EntityColdStarter( + entityColdStarter = new ADEntityColdStart( clock, threadPool, stateManager, - AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE, - AnomalyDetectorSettings.NUM_TREES, - AnomalyDetectorSettings.TIME_DECAY, + TimeSeriesSettings.NUM_SAMPLES_PER_TREE, + TimeSeriesSettings.NUM_TREES, + TimeSeriesSettings.TIME_DECAY, numMinSamples, AnomalyDetectorSettings.MAX_SAMPLE_STRIDE, AnomalyDetectorSettings.MAX_TRAIN_SAMPLE, interpolator, searchFeatureDao, - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE, + TimeSeriesSettings.THRESHOLD_MIN_PVALUE, featureManager, settings, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, - AnomalyDetectorSettings.MAX_COLD_START_ROUNDS + TimeSeriesSettings.MAX_COLD_START_ROUNDS ); modelManager = spy( - new ModelManager( + new ADModelManager( checkpointDao, clock, numTrees, @@ -974,50 +969,49 @@ public void getEmptyStateFullSamples() { ) ); - ModelState state = MLUtil + ADModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples).build()); - EntityModel model = state.getModel(); + createFromValueOnlySamples model = state.getModel(); assertTrue(!model.getTrcf().isPresent()); - ThresholdingResult result = modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); + ThresholdingResult result = modelManager.getResult(new double[] { -1 }, state, "", null, shingleSize); // model outputs scores assertTrue(result.getRcfScore() != 0); // added the sample to score since our model is empty - assertEquals(0, model.getSamples().size()); + assertEquals(0, model.getValueOnlySamples().size()); } @Test public void getAnomalyResultForEntityNoModel() { - ModelState modelState = new ModelState<>(null, modelId, detectorId, ModelType.ENTITY.getName(), clock, 0); + ADModelState modelState = new ADModelState<>( + null, + modelId, + detectorId, + ModelManager.ModelType.ENTITY.getName(), + clock, + 0 + ); ThresholdingResult result = modelManager - .getAnomalyResultForEntity( - new double[] { -1 }, - modelState, - modelId, - Entity.createSingleAttributeEntity("field", "val"), - shingleSize - ); + .getResult(new double[] { -1 }, modelState, modelId, Entity.createSingleAttributeEntity("field", "val"), shingleSize); // model outputs scores assertEquals(new ThresholdingResult(0, 0, 0), result); // added the sample to score since our model is empty - assertEquals(1, modelState.getModel().getSamples().size()); + assertEquals(1, modelState.getModel().getValueOnlySamples().size()); } @Test public void getEmptyStateNotFullSamples() { - ModelState state = MLUtil + ADModelState state = MLUtil .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).sampleSize(numMinSamples - 1).build()); - assertEquals( - new ThresholdingResult(0, 0, 0), - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize) - ); - assertEquals(numMinSamples, state.getModel().getSamples().size()); + assertEquals(new 
ThresholdingResult(0, 0, 0), modelManager.getResult(new double[] { -1 }, state, "", null, shingleSize)); + assertEquals(numMinSamples, state.getModel().getValueOnlySamples().size()); } @Test public void scoreSamples() { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - modelManager.getAnomalyResultForEntity(new double[] { -1 }, state, "", null, shingleSize); - assertEquals(0, state.getModel().getSamples().size()); + ADModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + modelManager.getResult(new double[] { -1 }, state, "", null, shingleSize); + assertEquals(0, state.getModel().getValueOnlySamples().size()); assertEquals(now, state.getLastUsedTime()); } @@ -1028,8 +1022,7 @@ public void getAnomalyResultForEntity_withTrcf() { anomalyDescriptor.setAnomalyGrade(1); when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); - ThresholdingResult result = modelManager - .getAnomalyResultForEntity(this.point, this.modelState, this.detectorId, null, this.shingleSize); + ThresholdingResult result = modelManager.getResult(this.point, this.modelState, this.detectorId, null, this.shingleSize); assertEquals( new ThresholdingResult( anomalyDescriptor.getAnomalyGrade(), @@ -1052,7 +1045,7 @@ public void score_with_trcf() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); when(this.trcf.process(this.point, 0)).thenReturn(anomalyDescriptor); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + when(this.entityModel.getValueOnlySamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); ThresholdingResult result = modelManager.score(this.point, this.detectorId, this.modelState); assertEquals( @@ -1085,7 +1078,7 @@ public void score_throw() { when(rcf.getDimensions()).thenReturn(40); when(this.trcf.getForest()).thenReturn(rcf); doThrow(new IllegalArgumentException()).when(trcf).process(any(), anyLong()); - when(this.entityModel.getSamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); + when(this.entityModel.getValueOnlySamples()).thenReturn(new ArrayDeque<>(Arrays.asList(this.point))); modelManager.score(this.point, this.detectorId, this.modelState); } } diff --git a/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java b/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java index 59a0d02da..bda02043a 100644 --- a/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java +++ b/src/test/java/org/opensearch/ad/ml/SingleStreamModelIdMapperTests.java @@ -12,6 +12,7 @@ package org.opensearch.ad.ml; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.ml.SingleStreamModelIdMapper; public class SingleStreamModelIdMapperTests extends OpenSearchTestCase { public void testGetThresholdModelIdFromRCFModelId() { diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java index 327e3bf51..9a58b9c4f 100644 --- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java +++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobAction.java @@ -12,16 +12,16 @@ package org.opensearch.ad.mock.transport; import org.opensearch.action.ActionType; -import org.opensearch.ad.constant.CommonValue; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; +import 
org.opensearch.ad.constant.ADCommonValue;
+import org.opensearch.timeseries.transport.JobResponse;

-public class MockAnomalyDetectorJobAction extends ActionType<AnomalyDetectorJobResponse> {
+public class MockAnomalyDetectorJobAction extends ActionType<JobResponse> {
     // External action used by public-facing REST APIs.
-    public static final String NAME = CommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement";
+    public static final String NAME = ADCommonValue.EXTERNAL_ACTION_PREFIX + "detector/mockjobmanagement";
     public static final MockAnomalyDetectorJobAction INSTANCE = new MockAnomalyDetectorJobAction();

     private MockAnomalyDetectorJobAction() {
-        super(NAME, AnomalyDetectorJobResponse::new);
+        super(NAME, JobResponse::new);
     }
 }
diff --git a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java
index 15d37c89d..7b6b4d865 100644
--- a/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java
+++ b/src/test/java/org/opensearch/ad/mock/transport/MockAnomalyDetectorJobTransportActionWithUser.java
@@ -11,8 +11,8 @@
 package org.opensearch.ad.mock.transport;

-import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT;
 import static org.opensearch.timeseries.util.ParseUtils.resolveUserAndExecute;

 import org.apache.logging.log4j.LogManager;
@@ -21,12 +21,13 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.HandledTransportAction;
 import org.opensearch.ad.ExecuteADResultResponseRecorder;
+import org.opensearch.ad.indices.ADIndex;
 import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler;
-import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.transport.AnomalyDetectorJobRequest;
-import org.opensearch.ad.transport.AnomalyDetectorJobResponse;
+import org.opensearch.ad.model.ADTask;
+import org.opensearch.ad.model.ADTaskType;
+import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.transport.AnomalyDetectorJobTransportAction;
+import org.opensearch.ad.transport.GetAnomalyDetectorResponse;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.inject.Inject;
@@ -37,11 +38,14 @@
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.tasks.Task;
 import org.opensearch.timeseries.model.DateRange;
+import org.opensearch.timeseries.rest.handler.IndexJobActionHandler;
+import org.opensearch.timeseries.task.TaskManager;
+import org.opensearch.timeseries.transport.JobRequest;
+import org.opensearch.timeseries.transport.JobResponse;
 import org.opensearch.timeseries.util.RestHandlerUtils;
 import org.opensearch.transport.TransportService;

-public class MockAnomalyDetectorJobTransportActionWithUser extends
-    HandledTransportAction<AnomalyDetectorJobRequest, AnomalyDetectorJobResponse> {
+public class MockAnomalyDetectorJobTransportActionWithUser extends HandledTransportAction<JobRequest, JobResponse> {

     private final Logger logger = LogManager.getLogger(AnomalyDetectorJobTransportAction.class);
     private final Client client;
@@ -51,7 +55,7 @@ public class MockAnomalyDetectorJobTransportActionWithUser extends
     private 
final NamedXContentRegistry xContentRegistry; private volatile Boolean filterByEnabled; private ThreadContext.StoredContext context; - private final ADTaskManager adTaskManager; + private final TaskManager adTaskManager; private final TransportService transportService; private final ExecuteADResultResponseRecorder recorder; @@ -64,10 +68,10 @@ public MockAnomalyDetectorJobTransportActionWithUser( Settings settings, ADIndexManagement anomalyDetectionIndices, NamedXContentRegistry xContentRegistry, - ADTaskManager adTaskManager, + TaskManager adTaskManager, ExecuteADResultResponseRecorder recorder ) { - super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, AnomalyDetectorJobRequest::new); + super(MockAnomalyDetectorJobAction.NAME, transportService, actionFilters, JobRequest::new); this.transportService = transportService; this.client = client; this.clusterService = clusterService; @@ -75,8 +79,8 @@ public MockAnomalyDetectorJobTransportActionWithUser( this.anomalyDetectionIndices = anomalyDetectionIndices; this.xContentRegistry = xContentRegistry; this.adTaskManager = adTaskManager; - filterByEnabled = FILTER_BY_BACKEND_ROLES.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); + filterByEnabled = AD_FILTER_BY_BACKEND_ROLES.get(settings); + clusterService.getClusterSettings().addSettingsUpdateConsumer(AD_FILTER_BY_BACKEND_ROLES, it -> filterByEnabled = it); ThreadContext threadContext = new ThreadContext(settings); context = threadContext.stashContext(); @@ -84,14 +88,14 @@ public MockAnomalyDetectorJobTransportActionWithUser( } @Override - protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionListener listener) { - String detectorId = request.getDetectorID(); - DateRange detectionDateRange = request.getDetectionDateRange(); + protected void doExecute(Task task, JobRequest request, ActionListener listener) { + String detectorId = request.getConfigID(); + DateRange detectionDateRange = request.getDateRange(); boolean historical = request.isHistorical(); long seqNo = request.getSeqNo(); long primaryTerm = request.getPrimaryTerm(); String rawPath = request.getRawPath(); - TimeValue requestTimeout = REQUEST_TIMEOUT.get(settings); + TimeValue requestTimeout = AD_REQUEST_TIMEOUT.get(settings); String userStr = "user_name|backendrole1,backendrole2|roles1,role2"; // By the time request reaches here, the user permissions are validated by Security plugin. 
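// Editor's sketch, not part of the diff: the accessor renames on the shared
// JobRequest that this mock adopts below. JobRequest, DateRange and the
// AD_-prefixed setting constant come from the hunks above; the helper method
// itself is hypothetical.
private String describeRequest(JobRequest request, Settings settings) {
    String configId = request.getConfigID();             // was getDetectorID()
    DateRange dateRange = request.getDateRange();         // was getDetectionDateRange()
    TimeValue timeout = AD_REQUEST_TIMEOUT.get(settings); // was REQUEST_TIMEOUT
    return configId + " historical=" + request.isHistorical() + " range=" + dateRange + " timeout=" + timeout;
}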
User user = User.parse(userStr); @@ -114,7 +118,8 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionLis ), client, clusterService, - xContentRegistry + xContentRegistry, + GetAnomalyDetectorResponse.class ); } catch (Exception e) { logger.error(e); @@ -123,7 +128,7 @@ protected void doExecute(Task task, AnomalyDetectorJobRequest request, ActionLis } private void executeDetector( - ActionListener listener, + ActionListener listener, String detectorId, long seqNo, long primaryTerm, @@ -133,7 +138,7 @@ private void executeDetector( DateRange detectionDateRange, boolean historical ) { - IndexAnomalyDetectorJobActionHandler handler = new IndexAnomalyDetectorJobActionHandler( + IndexJobActionHandler handler = new IndexJobActionHandler( client, anomalyDetectionIndices, detectorId, diff --git a/src/test/java/org/opensearch/ad/model/ADTaskTests.java b/src/test/java/org/opensearch/ad/model/ADTaskTests.java index 1cd2e6cc8..d97dc15dd 100644 --- a/src/test/java/org/opensearch/ad/model/ADTaskTests.java +++ b/src/test/java/org/opensearch/ad/model/ADTaskTests.java @@ -25,6 +25,7 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.TaskState; public class ADTaskTests extends OpenSearchSingleNodeTestCase { @@ -39,7 +40,7 @@ protected NamedWriteableRegistry writableRegistry() { } public void testAdTaskSerialization() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), true); + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), TaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), true); BytesStreamOutput output = new BytesStreamOutput(); adTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); @@ -48,7 +49,7 @@ public void testAdTaskSerialization() throws IOException { } public void testAdTaskSerializationWithNullDetector() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), ADTaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), false); + ADTask adTask = TestHelpers.randomAdTask(randomAlphaOfLength(5), TaskState.STOPPED, Instant.now(), randomAlphaOfLength(5), false); BytesStreamOutput output = new BytesStreamOutput(); adTask.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); @@ -58,7 +59,7 @@ public void testAdTaskSerializationWithNullDetector() throws IOException { public void testParseADTask() throws IOException { ADTask adTask = TestHelpers - .randomAdTask(null, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + .randomAdTask(null, TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); String taskId = randomAlphaOfLength(5); adTask.setTaskId(taskId); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); @@ -69,7 +70,7 @@ public void testParseADTask() throws IOException { public void testParseADTaskWithoutTaskId() throws IOException { String taskId = null; ADTask adTask = TestHelpers - .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); + .randomAdTask(taskId, 
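// Editor's sketch, not part of the diff: ADTaskState collapses into the shared
// org.opensearch.timeseries.model.TaskState, so task fixtures are now built as
// below. The call shape is copied from the hunks in this file; the
// argument-role comments are inferred, not confirmed by the diff.
ADTask stopped = TestHelpers
    .randomAdTask(
        randomAlphaOfLength(5),                        // task id
        TaskState.STOPPED,                             // was ADTaskState.STOPPED
        Instant.now().truncatedTo(ChronoUnit.SECONDS), // execution end time (inferred)
        randomAlphaOfLength(5),                        // stopped-by user (inferred)
        true                                           // attach a detector; false exercises the null-detector path
    );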
TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), true); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString)); assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); @@ -78,7 +79,7 @@ public void testParseADTaskWithoutTaskId() throws IOException { public void testParseADTaskWithNullDetector() throws IOException { String taskId = randomAlphaOfLength(5); ADTask adTask = TestHelpers - .randomAdTask(taskId, ADTaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), false); + .randomAdTask(taskId, TaskState.STOPPED, Instant.now().truncatedTo(ChronoUnit.SECONDS), randomAlphaOfLength(5), false); String adTaskString = TestHelpers.xContentBuilderToString(adTask.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); ADTask parsedADTask = ADTask.parse(TestHelpers.parser(adTaskString), taskId); assertEquals("Parsing AD task doesn't work", adTask, parsedADTask); diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java index 75d821507..897a344ba 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyDetectorJobTests.java @@ -24,6 +24,8 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; +import org.opensearch.timeseries.model.Job; + public class AnomalyDetectorJobTests extends OpenSearchSingleNodeTestCase { @@ -38,22 +40,22 @@ protected NamedWriteableRegistry writableRegistry() { } public void testParseAnomalyDetectorJob() throws IOException { - AnomalyDetectorJob anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); + Job anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); String anomalyDetectorJobString = TestHelpers .xContentBuilderToString(anomalyDetectorJob.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); anomalyDetectorJobString = anomalyDetectorJobString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetectorJob parsedAnomalyDetectorJob = AnomalyDetectorJob.parse(TestHelpers.parser(anomalyDetectorJobString)); + Job parsedAnomalyDetectorJob = Job.parse(TestHelpers.parser(anomalyDetectorJobString)); assertEquals("Parsing anomaly detect result doesn't work", anomalyDetectorJob, parsedAnomalyDetectorJob); } public void testSerialization() throws IOException { - AnomalyDetectorJob anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); + Job anomalyDetectorJob = TestHelpers.randomAnomalyDetectorJob(); BytesStreamOutput output = new BytesStreamOutput(); anomalyDetectorJob.writeTo(output); NamedWriteableAwareStreamInput input = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), writableRegistry()); - AnomalyDetectorJob parsedAnomalyDetectorJob = new AnomalyDetectorJob(input); + Job parsedAnomalyDetectorJob = new Job(input); assertNotNull(parsedAnomalyDetectorJob); } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java index d3298eae2..3a0f87f80 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java +++ 
b/src/test/java/org/opensearch/ad/model/AnomalyDetectorTests.java @@ -11,10 +11,8 @@ package org.opensearch.ad.model; -import static org.opensearch.ad.constant.ADCommonMessages.INVALID_RESULT_INDEX_PREFIX; import static org.opensearch.ad.constant.ADCommonName.CUSTOM_RESULT_INDEX_PREFIX; import static org.opensearch.ad.model.AnomalyDetector.MAX_RESULT_INDEX_NAME_SIZE; -import static org.opensearch.timeseries.constant.CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME; import java.io.IOException; import java.time.Instant; @@ -30,6 +28,8 @@ import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.ValidationException; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.model.Config; import org.opensearch.timeseries.model.IntervalTimeConfiguration; import org.opensearch.timeseries.settings.TimeSeriesSettings; @@ -39,18 +39,18 @@ public class AnomalyDetectorTests extends AbstractTimeSeriesTest { public void testParseAnomalyDetector() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + Config detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); LOG.info(detectorString); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } public void testParseAnomalyDetectorWithCustomIndex() throws IOException { String resultIndex = ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "test"; - AnomalyDetector detector = TestHelpers + Config detector = TestHelpers .randomDetector( ImmutableList.of(TestHelpers.randomFeature()), randomAlphaOfLength(5), @@ -63,7 +63,7 @@ public void testParseAnomalyDetectorWithCustomIndex() throws IOException { LOG.info(detectorString); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing result index doesn't work", resultIndex, parsedDetector.getCustomResultIndex()); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } @@ -86,30 +86,30 @@ public void testAnomalyDetectorWithInvalidCustomIndex() throws Exception { } public void testParseAnomalyDetectorWithoutParams() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + Config detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder())); LOG.info(detectorString); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = 
AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } public void testParseAnomalyDetectorWithCustomDetectionDelay() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); + Config detector = TestHelpers.randomAnomalyDetector(TestHelpers.randomUiMetadata(), Instant.now()); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder())); LOG.info(detectorString); TimeValue detectionInterval = new TimeValue(1, TimeUnit.MINUTES); TimeValue detectionWindowDelay = new TimeValue(10, TimeUnit.MINUTES); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = AnomalyDetector + Config parsedDetector = Config .parse(TestHelpers.parser(detectorString), detector.getId(), detector.getVersion(), detectionInterval, detectionWindowDelay); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } public void testParseSingleEntityAnomalyDetector() throws IOException { - AnomalyDetector detector = TestHelpers + Config detector = TestHelpers .randomAnomalyDetector( ImmutableList.of(TestHelpers.randomFeature()), TestHelpers.randomUiMetadata(), @@ -120,12 +120,12 @@ public void testParseSingleEntityAnomalyDetector() throws IOException { LOG.info(detectorString); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } public void testParseHistoricalAnomalyDetectorWithoutUser() throws IOException { - AnomalyDetector detector = TestHelpers + Config detector = TestHelpers .randomAnomalyDetector( ImmutableList.of(TestHelpers.randomFeature()), TestHelpers.randomUiMetadata(), @@ -137,7 +137,7 @@ public void testParseHistoricalAnomalyDetectorWithoutUser() throws IOException { LOG.info(detectorString); detectorString = detectorString .replaceFirst("\\{", String.format(Locale.ROOT, "{\"%s\":\"%s\",", randomAlphaOfLength(5), randomAlphaOfLength(5))); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } @@ -150,7 +150,7 @@ public void testParseAnomalyDetectorWithNullFilterQuery() throws IOException { + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); } @@ -164,7 +164,7 @@ public void testParseAnomalyDetectorWithEmptyFilterQuery() throws IOException { + 
"{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}}," + "\"last_update_time\":1568396089028}"; - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); } @@ -189,7 +189,7 @@ public void testParseAnomalyDetectorWithoutOptionalParams() throws IOException { + "{\"period\":{\"interval\":425,\"unit\":\"Minutes\"}},\"schema_version\":-1203962153,\"ui_metadata\":" + "{\"JbAaV\":{\"feature_id\":\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false," + "\"aggregation_query\":{\"aa\":{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertTrue(parsedDetector.getFilterQuery() instanceof MatchAllQueryBuilder); assertEquals((long) parsedDetector.getShingleSize(), (long) TimeSeriesSettings.DEFAULT_SHINGLE_SIZE); } @@ -252,7 +252,7 @@ public void testParseAnomalyDetectorWithInvalidDetectorIntervalUnits() { + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> AnomalyDetector.parse(TestHelpers.parser(detectorString)) + () -> Config.parse(TestHelpers.parser(detectorString)) ); assertEquals( String.format(Locale.ROOT, ADCommonMessages.INVALID_TIME_CONFIGURATION_UNITS, ChronoUnit.MILLIS), @@ -271,7 +271,7 @@ public void testParseAnomalyDetectorInvalidWindowDelayUnits() { + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> AnomalyDetector.parse(TestHelpers.parser(detectorString)) + () -> Config.parse(TestHelpers.parser(detectorString)) ); assertEquals( String.format(Locale.ROOT, ADCommonMessages.INVALID_TIME_CONFIGURATION_UNITS, ChronoUnit.MILLIS), @@ -280,17 +280,17 @@ public void testParseAnomalyDetectorInvalidWindowDelayUnits() { } public void testParseAnomalyDetectorWithNullUiMetadata() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); + Config detector = TestHelpers.randomAnomalyDetector(null, Instant.now()); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); assertNull(parsedDetector.getUiMetadata()); } public void testParseAnomalyDetectorWithEmptyUiMetadata() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); + Config detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of(), Instant.now()); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = 
Config.parse(TestHelpers.parser(detectorString)); assertEquals("Parsing anomaly detector doesn't work", detector, parsedDetector); } @@ -513,7 +513,7 @@ public void testInvalidDetectionInterval() { public void testInvalidWindowDelay() { IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> new AnomalyDetector( + () -> new Config( randomAlphaOfLength(10), randomLong(), randomAlphaOfLength(20), @@ -538,22 +538,21 @@ public void testInvalidWindowDelay() { } public void testNullFeatures() throws IOException { - AnomalyDetector detector = TestHelpers.randomAnomalyDetector(null, null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); + Config detector = TestHelpers.randomAnomalyDetector(null, null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals(0, parsedDetector.getFeatureAttributes().size()); } public void testEmptyFeatures() throws IOException { - AnomalyDetector detector = TestHelpers - .randomAnomalyDetector(ImmutableList.of(), null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); + Config detector = TestHelpers.randomAnomalyDetector(ImmutableList.of(), null, Instant.now().truncatedTo(ChronoUnit.SECONDS)); String detectorString = TestHelpers.xContentBuilderToString(detector.toXContent(TestHelpers.builder(), ToXContent.EMPTY_PARAMS)); - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString)); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString)); assertEquals(0, parsedDetector.getFeatureAttributes().size()); } public void testGetShingleSize() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new Config( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -577,7 +576,7 @@ public void testGetShingleSize() throws IOException { } public void testGetShingleSizeReturnsDefaultValue() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new Config( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -601,7 +600,7 @@ public void testGetShingleSizeReturnsDefaultValue() throws IOException { } public void testNullFeatureAttributes() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( + Config anomalyDetector = new Config( randomAlphaOfLength(5), randomLong(), randomAlphaOfLength(5), @@ -625,43 +624,22 @@ public void testNullFeatureAttributes() throws IOException { assertEquals(0, anomalyDetector.getFeatureAttributes().size()); } - public void testValidateResultIndex() throws IOException { - AnomalyDetector anomalyDetector = new AnomalyDetector( - randomAlphaOfLength(5), - randomLong(), - randomAlphaOfLength(5), - randomAlphaOfLength(5), - randomAlphaOfLength(5), - ImmutableList.of(randomAlphaOfLength(5)), - ImmutableList.of(TestHelpers.randomFeature()), - TestHelpers.randomQuery(), - TestHelpers.randomIntervalTimeConfiguration(), - TestHelpers.randomIntervalTimeConfiguration(), - null, - null, - 1, - Instant.now(), - null, - TestHelpers.randomUser(), - null, - TestHelpers.randomImputationOption() - ); - - String errorMessage = anomalyDetector.validateCustomResultIndex("abc"); - assertEquals(INVALID_RESULT_INDEX_PREFIX, errorMessage); + public void 
testValidateResultIndex() { + String errorMessage = Config.validateCustomResultIndex("abc"); + assertEquals(ADCommonMessages.INVALID_RESULT_INDEX_PREFIX, errorMessage); StringBuilder resultIndexNameBuilder = new StringBuilder(CUSTOM_RESULT_INDEX_PREFIX); for (int i = 0; i < MAX_RESULT_INDEX_NAME_SIZE - CUSTOM_RESULT_INDEX_PREFIX.length(); i++) { resultIndexNameBuilder.append("a"); } - assertNull(anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString())); + assertNull(Config.validateCustomResultIndex(resultIndexNameBuilder.toString())); resultIndexNameBuilder.append("a"); - errorMessage = anomalyDetector.validateCustomResultIndex(resultIndexNameBuilder.toString()); - assertEquals(AnomalyDetector.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); + errorMessage = Config.validateCustomResultIndex(resultIndexNameBuilder.toString()); + assertEquals(Config.INVALID_RESULT_INDEX_NAME_SIZE, errorMessage); - errorMessage = anomalyDetector.validateCustomResultIndex(CUSTOM_RESULT_INDEX_PREFIX + "abc#"); - assertEquals(INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); + errorMessage = Config.validateCustomResultIndex(CUSTOM_RESULT_INDEX_PREFIX + "abc#"); + assertEquals(CommonMessages.INVALID_CHAR_IN_RESULT_INDEX_NAME, errorMessage); } public void testParseAnomalyDetectorWithNoDescription() throws IOException { @@ -672,7 +650,7 @@ public void testParseAnomalyDetectorWithNoDescription() throws IOException { + "\"unit\":\"Minutes\"}},\"shingle_size\":4,\"schema_version\":-1203962153,\"ui_metadata\":{\"JbAaV\":{\"feature_id\":" + "\"rIFjS\",\"feature_name\":\"QXCmS\",\"feature_enabled\":false,\"aggregation_query\":{\"aa\":" + "{\"value_count\":{\"field\":\"ok\"}}}}},\"last_update_time\":1568396089028}"; - AnomalyDetector parsedDetector = AnomalyDetector.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); + Config parsedDetector = Config.parse(TestHelpers.parser(detectorString), "id", 1L, null, null); assertEquals(parsedDetector.getDescription(), ""); } } diff --git a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java index 424de19da..28245aa31 100644 --- a/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java +++ b/src/test/java/org/opensearch/ad/model/AnomalyResultTests.java @@ -11,8 +11,6 @@ package org.opensearch.ad.model; -import static org.opensearch.test.OpenSearchTestCase.randomDouble; - import java.io.IOException; import java.util.Collection; import java.util.Locale; diff --git a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java index 9960a5fe2..dc41882d5 100644 --- a/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java +++ b/src/test/java/org/opensearch/ad/model/DetectorProfileTests.java @@ -21,13 +21,15 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.model.ConfigState; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; public class DetectorProfileTests extends OpenSearchTestCase { private DetectorProfile createRandomDetectorProfile() { return new DetectorProfile.Builder() - .state(DetectorState.INIT) + .state(ConfigState.INIT) .error(randomAlphaOfLength(5)) .modelProfile( new ModelProfileOnNode[] { diff --git a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java 
b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java
index 24cb0c879..addf84022 100644
--- a/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java
+++ b/src/test/java/org/opensearch/ad/model/EntityAnomalyResultTests.java
@@ -18,7 +18,7 @@
 import java.util.List;

 import org.junit.Test;
-import org.opensearch.ad.stats.ADStatsResponse;
+import org.opensearch.ad.transport.ADStatsResponse;
 import org.opensearch.test.OpenSearchTestCase;

 public class EntityAnomalyResultTests extends OpenSearchTestCase {
diff --git a/src/test/java/org/opensearch/ad/model/MergeableListTests.java b/src/test/java/org/opensearch/ad/model/MergeableListTests.java
index f9d794da6..1375d72a7 100644
--- a/src/test/java/org/opensearch/ad/model/MergeableListTests.java
+++ b/src/test/java/org/opensearch/ad/model/MergeableListTests.java
@@ -15,6 +15,7 @@
 import java.util.List;

 import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.timeseries.model.MergeableList;

 public class MergeableListTests extends AbstractTimeSeriesTest {
diff --git a/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java
index 075cf46c3..d6d7d40da 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/AbstractRateLimitingTest.java
@@ -22,7 +22,7 @@
 import java.util.Optional;

 import org.opensearch.action.ActionListener;
-import org.opensearch.ad.NodeStateManager;
+import org.opensearch.ad.ADNodeStateManager;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
@@ -32,7 +32,7 @@ public class AbstractRateLimitingTest extends AbstractTimeSeriesTest {
     Clock clock;
     AnomalyDetector detector;
-    NodeStateManager nodeStateManager;
+    ADNodeStateManager nodeStateManager;
     String detectorId;
     String categoryField;
     Entity entity, entity2, entity3;
@@ -52,12 +52,12 @@ public void setUp() throws Exception {
         detectorId = "123";
         detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(categoryField));

-        nodeStateManager = mock(NodeStateManager.class);
+        nodeStateManager = mock(ADNodeStateManager.class);
         doAnswer(invocation -> {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class));

         entity = Entity.createSingleAttributeEntity(categoryField, "value");
         entity2 = Entity.createSingleAttributeEntity(categoryField, "value2");
diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java
index d1fe526de..e1c9fb873 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/CheckPointMaintainRequestAdapterTests.java
@@ -27,37 +27,40 @@
 import java.util.Optional;

 import org.opensearch.action.update.UpdateRequest;
-import org.opensearch.ad.caching.CacheProvider;
 import org.opensearch.ad.caching.EntityCache;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.CheckpointDao;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
 import 
org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.caching.EntityCache; +import org.opensearch.timeseries.caching.HCCacheProvider; +import org.opensearch.timeseries.ml.createFromValueOnlySamples; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.CheckpointWriteRequest; +import org.opensearch.timeseries.ratelimit.ModelRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckPointMaintainRequestAdapterTests extends AbstractRateLimitingTest { - private CacheProvider cache; - private CheckpointDao checkpointDao; + private HCCacheProvider cache; + private ADCheckpointDao checkpointDao; private String indexName; private Setting checkpointInterval; private CheckPointMaintainRequestAdapter adapter; - private ModelState state; - private CheckpointMaintainRequest request; + private ADModelState state; + private ModelRequest request; private ClusterService clusterService; @Override public void setUp() throws Exception { super.setUp(); - cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + cache = mock(HCCacheProvider.class); + checkpointDao = mock(ADCheckpointDao.class); indexName = ADCommonName.CHECKPOINT_INDEX_NAME; checkpointInterval = AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; EntityCache entityCache = mock(EntityCache.class); @@ -79,7 +82,7 @@ public void setUp() throws Exception { clusterService, Settings.EMPTY ); - request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity.getModelId(detectorId).get()); + request = new ModelRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, entity.getModelId(detectorId).get()); } @@ -105,7 +108,7 @@ public void testModelIdEmpty() throws IOException { Map content = new HashMap(); content.put("a", "b"); when(checkpointDao.toIndexSource(any())).thenReturn(content); - assertTrue(adapter.convert(new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, null)).isEmpty()); + assertTrue(adapter.convert(new ModelRequest(Integer.MAX_VALUE, detectorId, RequestPriority.MEDIUM, null)).isEmpty()); } public void testNormal() throws IOException { diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java index cba7e8a45..3d4848313 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointMaintainWorkerTests.java @@ -32,31 +32,34 @@ import java.util.Optional; import java.util.Random; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelState; import org.opensearch.ad.settings.AnomalyDetectorSettings; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import 
org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.timeseries.breaker.CircuitBreakerService; +import org.opensearch.timeseries.caching.EntityCache; +import org.opensearch.timeseries.caching.HCCacheProvider; +import org.opensearch.timeseries.ml.createFromValueOnlySamples; +import org.opensearch.timeseries.ratelimit.CheckPointMaintainRequestAdapter; +import org.opensearch.timeseries.ratelimit.ModelRequest; +import org.opensearch.timeseries.ratelimit.RequestPriority; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; public class CheckpointMaintainWorkerTests extends AbstractRateLimitingTest { ClusterService clusterService; - CheckpointMaintainWorker cpMaintainWorker; - CheckpointWriteWorker writeWorker; - CheckpointMaintainRequest request; - CheckpointMaintainRequest request2; - List requests; - CheckpointDao checkpointDao; + ADCheckpointMaintainWorker cpMaintainWorker; + ADCheckpointWriteWorker writeWorker; + ModelRequest request; + ModelRequest request2; + List requests; + ADCheckpointDao checkpointDao; @Override public void setUp() throws Exception { @@ -71,7 +74,7 @@ public void setUp() throws Exception { Arrays .asList( AnomalyDetectorSettings.AD_EXPECTED_CHECKPOINT_MAINTAIN_TIME_IN_MILLISECS, - AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ ) @@ -80,15 +83,16 @@ public void setUp() throws Exception { ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - writeWorker = mock(CheckpointWriteWorker.class); + writeWorker = mock(ADCheckpointWriteWorker.class); - CacheProvider cache = mock(CacheProvider.class); - checkpointDao = mock(CheckpointDao.class); + HCCacheProvider cache = mock(HCCacheProvider.class); + checkpointDao = mock(ADCheckpointDao.class); String indexName = ADCommonName.CHECKPOINT_INDEX_NAME; Setting checkpointInterval = AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ; EntityCache entityCache = mock(EntityCache.class); when(cache.get()).thenReturn(entityCache); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(entityCache.getForMaintainance(anyString(), anyString())).thenReturn(Optional.of(state)); CheckPointMaintainRequestAdapter adapter = new CheckPointMaintainRequestAdapter( cache, @@ -101,28 +105,28 @@ public void setUp() throws Exception { ); // Integer.MAX_VALUE makes a huge heap - cpMaintainWorker = new CheckpointMaintainWorker( + cpMaintainWorker = new ADCheckpointMaintainWorker( Integer.MAX_VALUE, AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_MAINTAIN_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), - mock(ADCircuitBreakerService.class), + mock(CircuitBreakerService.class), threadPool, settings, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, clock, - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - 
AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, writeWorker, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, nodeStateManager, adapter ); - request = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity.getModelId(detectorId).get()); - request2 = new CheckpointMaintainRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2.getModelId(detectorId).get()); + request = new ModelRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity.getModelId(detectorId).get()); + request2 = new ModelRequest(Integer.MAX_VALUE, detectorId, RequestPriority.LOW, entity2.getModelId(detectorId).get()); requests = new ArrayList<>(); requests.add(request); diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java index 76090cce9..74b8e0338 100644 --- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java +++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointReadWorkerTests.java @@ -46,21 +46,13 @@ import org.opensearch.action.get.GetResponse; import org.opensearch.action.get.MultiGetItemResponse; import org.opensearch.action.get.MultiGetResponse; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -75,7 +67,9 @@ import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin; import org.opensearch.timeseries.common.exception.LimitExceededException; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -83,19 +77,19 @@ import com.fasterxml.jackson.core.JsonParseException; public class CheckpointReadWorkerTests extends AbstractRateLimitingTest { - CheckpointReadWorker worker; + ADCheckpointReadWorker worker; - CheckpointDao checkpoint; + ADCheckpointDao checkpoint; ClusterService clusterService; - ModelState state; + ADModelState state; - CheckpointWriteWorker checkpointWriteQueue; - ModelManager modelManager; - EntityColdStartWorker coldstartQueue; - ResultWriteWorker resultWriteQueue; + ADCheckpointWriteWorker checkpointWriteQueue; + ADModelManager modelManager; + ADColdStartWorker coldstartQueue; + ADResultWriteWorker resultWriteQueue; ADIndexManagement anomalyDetectionIndices; - 
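// Editor's note, not part of the diff: the constant migration repeated through these
// worker constructors. Tunables shared with forecasting move from
// AnomalyDetectorSettings to TimeSeriesSettings (MAX_QUEUED_TASKS_RATIO,
// MEDIUM_SEGMENT_PRUNE_RATIO, LOW_SEGMENT_PRUNE_RATIO, MAINTENANCE_FREQ_CONSTANT,
// QUEUE_MAINTENANCE, HOURLY_MAINTENANCE), while AD-only queue settings stay behind an
// AD_ prefix, e.g. CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT becomes
// AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT. A minimal sketch of the new lookup; the
// TimeValue return type is assumed from the getSeconds() usage elsewhere in the diff:
private TimeValue maintenanceWindow() {
    return TimeSeriesSettings.HOURLY_MAINTENANCE; // was AnomalyDetectorSettings.HOURLY_MAINTENANCE
}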
CacheProvider cacheProvider; + EntityCacheProvider cacheProvider; EntityCache entityCache; EntityFeatureRequest request, request2, request3; ClusterSettings clusterSettings; @@ -112,7 +106,7 @@ public void setUp() throws Exception { new HashSet<>( Arrays .asList( - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY, AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE ) @@ -123,50 +117,53 @@ public void setUp() throws Exception { state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - checkpoint = mock(CheckpointDao.class); + checkpoint = mock(ADCheckpointDao.class); - Map.Entry entry = new SimpleImmutableEntry(state.getModel(), Instant.now()); + Map.Entry entry = new SimpleImmutableEntry( + state.getModel(), + Instant.now() + ); when(checkpoint.processGetResponse(any(), anyString())).thenReturn(Optional.of(entry)); - checkpointWriteQueue = mock(CheckpointWriteWorker.class); + checkpointWriteQueue = mock(ADCheckpointWriteWorker.class); - modelManager = mock(ModelManager.class); + modelManager = mock(ADModelManager.class); when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); when(modelManager.score(any(), anyString(), any())).thenReturn(new ThresholdingResult(0, 1, 0.7)); - coldstartQueue = mock(EntityColdStartWorker.class); - resultWriteQueue = mock(ResultWriteWorker.class); + coldstartQueue = mock(ADColdStartWorker.class); + resultWriteQueue = mock(ADResultWriteWorker.class); anomalyDetectionIndices = mock(ADIndexManagement.class); - cacheProvider = mock(CacheProvider.class); + cacheProvider = mock(EntityCacheProvider.class); entityCache = mock(EntityCache.class); when(cacheProvider.get()).thenReturn(entityCache); when(entityCache.hostIfPossible(any(), any())).thenReturn(true); - Map> statsMap = new HashMap>() { + Map> statsMap = new HashMap>() { { - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; adStats = new ADStats(statsMap); // Integer.MAX_VALUE makes a huge heap - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), - mock(ADCircuitBreakerService.class), + mock(CircuitBreakerService.class), threadPool, Settings.EMPTY, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, clock, - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, modelManager, checkpoint, coldstartQueue, @@ -174,7 +171,7 @@ public void setUp() throws Exception { nodeStateManager, anomalyDetectionIndices, cacheProvider, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, adStats ); @@ -232,11 +229,9 @@ private void 
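// Editor's sketch, not part of the diff: the renamed model-corruption stat wired with
// the shared stat types, generics restored for readability. The stat key's spelling
// (AD_MODEL_CORRUTPION_COUNT) is kept exactly as the hunks show it.
Map<String, TimeSeriesStat<?>> statsMap = new HashMap<>();
statsMap.put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
ADStats adStats = new ADStats(statsMap);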
regularTestSetUp(RegularSetUpConfig config) { state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(config.fullModel).build()); when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state); if (config.fullModel) { - when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) - .thenReturn(new ThresholdingResult(0, 1, 1)); + when(modelManager.getResult(any(), any(), anyString(), any(), anyInt())).thenReturn(new ThresholdingResult(0, 1, 1)); } else { - when(modelManager.getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt())) - .thenReturn(new ThresholdingResult(0, 0, 0)); + when(modelManager.getResult(any(), any(), anyString(), any(), anyInt())).thenReturn(new ThresholdingResult(0, 0, 0)); } List requests = new ArrayList<>(); @@ -531,21 +526,21 @@ public void testRemoveUnusedQueues() { ExecutorService executorService = mock(ExecutorService.class); when(threadPool.executor(TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME)).thenReturn(executorService); - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), - mock(ADCircuitBreakerService.class), + mock(CircuitBreakerService.class), threadPool, Settings.EMPTY, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, clock, - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, modelManager, checkpoint, coldstartQueue, @@ -553,7 +548,7 @@ public void testRemoveUnusedQueues() { nodeStateManager, anomalyDetectionIndices, cacheProvider, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, adStats ); @@ -561,10 +556,10 @@ public void testRemoveUnusedQueues() { regularTestSetUp(new RegularSetUpConfig.Builder().build()); assertTrue(!worker.isQueueEmpty()); - assertEquals(CheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); + assertEquals(ADCheckpointReadWorker.WORKER_NAME, worker.getWorkerName()); // make RequestQueue.expired return true - when(clock.instant()).thenReturn(Instant.now().plusSeconds(AnomalyDetectorSettings.HOURLY_MAINTENANCE.getSeconds() + 1)); + when(clock.instant()).thenReturn(Instant.now().plusSeconds(TimeSeriesSettings.HOURLY_MAINTENANCE.getSeconds() + 1)); // removed the expired queue worker.maintenance(); @@ -583,21 +578,21 @@ public void testSettingUpdatable() { maintenanceSetup(); // can host two requests in the queue - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( 2000, 1, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), - mock(ADCircuitBreakerService.class), + mock(CircuitBreakerService.class), threadPool, Settings.EMPTY, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, clock, - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - 
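// Editor's note, not part of the diff: the expiry pattern testRemoveUnusedQueues
// relies on above. Advancing the mocked clock one second past the now-shared
// HOURLY_MAINTENANCE window lets the next maintenance() pass drop the idle
// per-detector queue; clock and worker are this test class's fixtures.
when(clock.instant()).thenReturn(Instant.now().plusSeconds(TimeSeriesSettings.HOURLY_MAINTENANCE.getSeconds() + 1));
worker.maintenance(); // the expired queue is removed here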
AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, modelManager, checkpoint, coldstartQueue, @@ -605,7 +600,7 @@ public void testSettingUpdatable() { nodeStateManager, anomalyDetectionIndices, cacheProvider, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, adStats ); @@ -620,7 +615,7 @@ public void testSettingUpdatable() { Settings newSettings = Settings .builder() - .put(AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT.getKey(), "0.0001") + .put(AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT.getKey(), "0.0001") .build(); Settings.Builder target = Settings.builder(); clusterSettings.updateDynamicSettings(newSettings, target, Settings.builder(), "test"); @@ -633,24 +628,24 @@ public void testSettingUpdatable() { public void testOpenCircuitBreaker() { maintenanceSetup(); - ADCircuitBreakerService breaker = mock(ADCircuitBreakerService.class); + CircuitBreakerService breaker = mock(CircuitBreakerService.class); when(breaker.isOpen()).thenReturn(true); - worker = new CheckpointReadWorker( + worker = new ADCheckpointReadWorker( Integer.MAX_VALUE, AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES, - AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, + AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT, clusterService, new Random(42), breaker, threadPool, Settings.EMPTY, - AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO, + TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO, clock, - AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO, - AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT, - AnomalyDetectorSettings.QUEUE_MAINTENANCE, + TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO, + TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT, + TimeSeriesSettings.QUEUE_MAINTENANCE, modelManager, checkpoint, coldstartQueue, @@ -658,7 +653,7 @@ public void testOpenCircuitBreaker() { nodeStateManager, anomalyDetectionIndices, cacheProvider, - AnomalyDetectorSettings.HOURLY_MAINTENANCE, + TimeSeriesSettings.HOURLY_MAINTENANCE, checkpointWriteQueue, adStats ); @@ -711,10 +706,10 @@ public void testChangePriority() { } public void testDetectorId() { - assertEquals(detectorId, request.getId()); + assertEquals(detectorId, request.getConfigId()); String newDetectorId = "456"; request.setDetectorId(newDetectorId); - assertEquals(newDetectorId, request.getId()); + assertEquals(newDetectorId, request.getConfigId()); } @SuppressWarnings("unchecked") @@ -736,13 +731,13 @@ public void testHostException() throws IOException { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector2)); return null; - }).when(nodeStateManager).getAnomalyDetector(eq(detectorId2), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(eq(detectorId2), any(ActionListener.class)); doAnswer(invocation -> { ActionListener> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(nodeStateManager).getAnomalyDetector(eq(detectorId), any(ActionListener.class)); + }).when(nodeStateManager).getConfig(eq(detectorId), any(ActionListener.class)); doAnswer(invocation -> { 
            MultiGetItemResponse[] items = new MultiGetItemResponse[2];
@@ -802,7 +797,7 @@ public void testFailToScore() {
         state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build());
         when(modelManager.processEntityCheckpoint(any(), any(), anyString(), anyString(), anyInt())).thenReturn(state);
-        doThrow(new IllegalArgumentException()).when(modelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt());
+        doThrow(new IllegalArgumentException()).when(modelManager).getResult(any(), any(), anyString(), any(), anyInt());

         List requests = new ArrayList<>();
         requests.add(request);
@@ -811,7 +806,7 @@ public void testFailToScore() {
         verify(resultWriteQueue, never()).put(any());
         verify(checkpointWriteQueue, never()).write(any(), anyBoolean(), any());
         verify(coldstartQueue, times(1)).put(any());
-        Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue();
+        Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue();
         assertEquals(1L, ((Long) val).longValue());
     }
 }
diff --git a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
index 97e8370bf..7ae9cd3fd 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/CheckpointWriteWorkerTests.java
@@ -46,11 +46,7 @@
 import org.opensearch.action.bulk.BulkItemResponse.Failure;
 import org.opensearch.action.bulk.BulkResponse;
 import org.opensearch.action.index.IndexResponse;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.CheckpointDao;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.service.ClusterService;
@@ -64,17 +60,18 @@
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.settings.TimeSeriesSettings;

 import test.org.opensearch.ad.util.MLUtil;
 import test.org.opensearch.ad.util.RandomModelStateConfig;

 public class CheckpointWriteWorkerTests extends AbstractRateLimitingTest {
-    CheckpointWriteWorker worker;
+    ADCheckpointWriteWorker worker;

-    CheckpointDao checkpoint;
+    ADCheckpointDao checkpoint;
     ClusterService clusterService;

-    ModelState state;
+    ADModelState state;

     @Override
     @SuppressWarnings("unchecked")
@@ -88,7 +85,7 @@ public void setUp() throws Exception {
                 new HashSet<>(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
                             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY,
                             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE
                         )
                 )
@@ -97,33 +94,33 @@ public void setUp() throws Exception {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

-        checkpoint = mock(CheckpointDao.class);
+        checkpoint = mock(ADCheckpointDao.class);
         Map checkpointMap = new HashMap<>();
         checkpointMap.put(CommonName.FIELD_MODEL, "a");
         when(checkpoint.toIndexSource(any())).thenReturn(checkpointMap);
         when(checkpoint.shouldSave(any(), anyBoolean(), any(), any())).thenReturn(true);

         // Integer.MAX_VALUE makes a huge heap
-        worker = new CheckpointWriteWorker(
+        worker = new ADCheckpointWriteWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+            TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
+            AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             threadPool,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
-            AnomalyDetectorSettings.QUEUE_MAINTENANCE,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.QUEUE_MAINTENANCE,
             checkpoint,
             ADCommonName.CHECKPOINT_INDEX_NAME,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             nodeStateManager,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE
+            TimeSeriesSettings.HOURLY_MAINTENANCE
         );

         state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build());
@@ -164,7 +161,7 @@ public void testTriggerSaveAll() {
             return null;
         }).when(checkpoint).batchWrite(any(), any());

-        List> states = new ArrayList<>();
+        List> states = new ArrayList<>();
         states.add(state);
         worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM);
@@ -208,26 +205,26 @@ public void testTriggerAutoFlush() throws InterruptedException {
         // Integer.MAX_VALUE makes a huge heap
         // create a worker to use mockThreadPool
-        worker = new CheckpointWriteWorker(
+        worker = new ADCheckpointWriteWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+            TimeSeriesSettings.CHECKPOINT_WRITE_QUEUE_SIZE_IN_BYTES,
+            AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             mockThreadPool,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
-            AnomalyDetectorSettings.QUEUE_MAINTENANCE,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.QUEUE_MAINTENANCE,
             checkpoint,
             ADCommonName.CHECKPOINT_INDEX_NAME,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             nodeStateManager,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE
+            TimeSeriesSettings.HOURLY_MAINTENANCE
         );

         // our concurrency is 2, so first 2 requests cause two batches. And the
@@ -237,7 +234,7 @@ public void testTriggerAutoFlush() throws InterruptedException {
         // CHECKPOINT_WRITE_QUEUE_BATCH_SIZE is the largest batch size
         int numberOfRequests = 2 * AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE.getDefault(Settings.EMPTY) + 1;
         for (int i = 0; i < numberOfRequests; i++) {
-            ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build());
+            ADModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().build());
             worker.write(state, true, RequestPriority.MEDIUM);
         }
@@ -266,7 +263,7 @@ public void testOverloaded() {
         worker.write(state, true, RequestPriority.MEDIUM);

         verify(checkpoint, times(1)).batchWrite(any(), any());
-        verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchRejectedExecutionException.class));
+        verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchRejectedExecutionException.class));
     }

     public void testRetryException() {
@@ -280,7 +277,7 @@ public void testRetryException() {
         worker.write(state, true, RequestPriority.MEDIUM);
         // we don't retry checkpoint write
         verify(checkpoint, times(1)).batchWrite(any(), any());
-        verify(nodeStateManager, times(1)).setException(eq(state.getId()), any(OpenSearchStatusException.class));
+        verify(nodeStateManager, times(1)).setException(eq(state.getConfigId()), any(OpenSearchStatusException.class));
     }

     /**
@@ -308,7 +305,7 @@ public void testFailedRequest() {
     @SuppressWarnings("unchecked")
     public void testEmptyTimeStamp() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.MIN);
         worker.write(state, false, RequestPriority.MEDIUM);
@@ -317,7 +314,7 @@ public void testEmptyTimeStamp() {
     @SuppressWarnings("unchecked")
     public void testTooSoonToSaveSingleWrite() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
         worker.write(state, false, RequestPriority.MEDIUM);
@@ -326,10 +323,10 @@ public void testTooSoonToSaveSingleWrite() {
     @SuppressWarnings("unchecked")
     public void testTooSoonToSaveWriteAll() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());

-        List> states = new ArrayList<>();
+        List> states = new ArrayList<>();
         states.add(state);

         worker.writeAll(states, detectorId, false, RequestPriority.MEDIUM);
@@ -339,7 +336,7 @@ public void testTooSoonToSaveWriteAll() {
     @SuppressWarnings("unchecked")
     public void testEmptyModel() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
         when(state.getModel()).thenReturn(null);
         worker.write(state, true, RequestPriority.MEDIUM);
@@ -349,11 +346,11 @@ public void testEmptyModel() {
     @SuppressWarnings("unchecked")
     public void testEmptyModelId() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
-        EntityModel model = mock(EntityModel.class);
+        createFromValueOnlySamples model = mock(createFromValueOnlySamples.class);
         when(state.getModel()).thenReturn(model);
-        when(state.getId()).thenReturn("1");
+        when(state.getConfigId()).thenReturn("1");
         when(state.getModelId()).thenReturn(null);
         worker.write(state, true, RequestPriority.MEDIUM);
@@ -362,11 +359,11 @@ public void testEmptyModelId() {
     @SuppressWarnings("unchecked")
     public void testEmptyDetectorId() {
-        ModelState state = mock(ModelState.class);
+        ADModelState state = mock(ADModelState.class);
         when(state.getLastCheckpointTime()).thenReturn(Instant.now());
-        EntityModel model = mock(EntityModel.class);
+        createFromValueOnlySamples model = mock(createFromValueOnlySamples.class);
         when(state.getModel()).thenReturn(model);
-        when(state.getId()).thenReturn(null);
+        when(state.getConfigId()).thenReturn(null);
         when(state.getModelId()).thenReturn("a");
         worker.write(state, true, RequestPriority.MEDIUM);
@@ -379,7 +376,7 @@ public void testDetectorNotAvailableSingleWrite() {
             ActionListener> listener = invocation.getArgument(1);
             listener.onResponse(Optional.empty());
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class));

         worker.write(state, true, RequestPriority.MEDIUM);
         verify(checkpoint, never()).batchWrite(any(), any());
@@ -391,9 +388,9 @@ public void testDetectorNotAvailableWriteAll() {
             ActionListener> listener = invocation.getArgument(1);
             listener.onResponse(Optional.empty());
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class));

-        List> states = new ArrayList<>();
+        List> states = new ArrayList<>();
         states.add(state);
         worker.writeAll(states, detectorId, true, RequestPriority.MEDIUM);
         verify(checkpoint, never()).batchWrite(any(), any());
@@ -405,7 +402,7 @@ public void testDetectorFetchException() {
             ActionListener> listener = invocation.getArgument(1);
             listener.onFailure(new RuntimeException());
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(any(String.class), any(ActionListener.class));

         worker.write(state, true, RequestPriority.MEDIUM);
         verify(checkpoint, never()).batchWrite(any(), any());
diff --git a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
index f4af298c8..fc95120c8 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/ColdEntityWorkerTests.java
@@ -27,17 +27,20 @@
 import java.util.List;
 import java.util.Random;

-import org.opensearch.ad.breaker.ADCircuitBreakerService;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.unit.TimeValue;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.ratelimit.EntityFeatureRequest;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
+import org.opensearch.timeseries.settings.TimeSeriesSettings;

 public class ColdEntityWorkerTests extends AbstractRateLimitingTest {
     ClusterService clusterService;
-    ColdEntityWorker coldWorker;
-    CheckpointReadWorker readWorker;
+    ADColdEntityWorker coldWorker;
+    ADCheckpointReadWorker readWorker;
     EntityFeatureRequest request, request2, invalidRequest;
     List requests;
@@ -53,8 +56,8 @@ public void setUp() throws Exception {
                 new HashSet<>(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
-                            AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
+                            AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
                             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE
                         )
                 )
@@ -62,25 +65,25 @@ public void setUp() throws Exception {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

-        readWorker = mock(CheckpointReadWorker.class);
+        readWorker = mock(ADCheckpointReadWorker.class);

         // Integer.MAX_VALUE makes a huge heap
-        coldWorker = new ColdEntityWorker(
+        coldWorker = new ADColdEntityWorker(
             Integer.MAX_VALUE,
             AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
+            AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             threadPool,
             settings,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
             readWorker,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             nodeStateManager
         );
@@ -99,7 +102,7 @@ public void setUp() throws Exception {
             TimeValue value = invocation.getArgument(1);
             // since we have only 1 request each time
-            long expectedExecutionPerRequestMilli = AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS
+            long expectedExecutionPerRequestMilli = AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS
                 .getDefault(Settings.EMPTY);
             long delay = value.getMillis();
             assertTrue(delay == expectedExecutionPerRequestMilli);
@@ -143,8 +146,8 @@ public void testDelay() {
                 new HashSet<>(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
-                            AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
+                            AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
                             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE
                         )
                 )
@@ -153,22 +156,22 @@ public void testDelay() {
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

         // Integer.MAX_VALUE makes a huge heap
-        coldWorker = new ColdEntityWorker(
+        coldWorker = new ADColdEntityWorker(
             Integer.MAX_VALUE,
             AnomalyDetectorSettings.ENTITY_FEATURE_REQUEST_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
+            AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             threadPool,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
             readWorker,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             nodeStateManager
         );
diff --git a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java
index 5580b5f30..b5691f978 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/EntityColdStartWorkerTests.java
@@ -29,25 +29,29 @@
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.action.ActionListener;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.ml.EntityColdStarter;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
+import org.opensearch.ad.ml.ADEntityColdStart;
+import org.opensearch.ad.ml.ADModelState;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
 import org.opensearch.core.rest.RestStatus;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.caching.HCCacheProvider;
+import org.opensearch.timeseries.ml.createFromValueOnlySamples;
+import org.opensearch.timeseries.ratelimit.EntityRequest;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
+import org.opensearch.timeseries.settings.TimeSeriesSettings;
+
 import test.org.opensearch.ad.util.MLUtil;

 public class EntityColdStartWorkerTests extends AbstractRateLimitingTest {
     ClusterService clusterService;
-    EntityColdStartWorker worker;
-    EntityColdStarter entityColdStarter;
-    CacheProvider cacheProvider;
+    ADColdStartWorker worker;
+    ADEntityColdStart entityColdStarter;
+    HCCacheProvider cacheProvider;

     @Override
     public void setUp() throws Exception {
@@ -60,36 +64,36 @@ public void setUp() throws Exception {
                 new HashSet<>(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY
+                            AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY
                         )
                 )
             )
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

-        entityColdStarter = mock(EntityColdStarter.class);
+        entityColdStarter = mock(ADEntityColdStart.class);

-        cacheProvider = mock(CacheProvider.class);
+        cacheProvider = mock(HCCacheProvider.class);

         // Integer.MAX_VALUE makes a huge heap
-        worker = new EntityColdStartWorker(
+        worker = new ADColdStartWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.ENTITY_REQUEST_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
+            TimeSeriesSettings.FEATURE_REQUEST_SIZE_IN_BYTES,
+            AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             threadPool,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
-            AnomalyDetectorSettings.QUEUE_MAINTENANCE,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.QUEUE_MAINTENANCE,
             entityColdStarter,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             nodeStateManager,
             cacheProvider
         );
@@ -151,7 +155,7 @@ public void testModelHosted() {
         doAnswer(invocation -> {
             ActionListener listener = invocation.getArgument(3);
-            ModelState state = invocation.getArgument(2);
+            ADModelState state = invocation.getArgument(2);
             state.setModel(MLUtil.createNonEmptyModel(detectorId));
             listener.onResponse(null);
diff --git a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java
index 4b46311c6..67348c08c 100644
--- a/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java
+++ b/src/test/java/org/opensearch/ad/ratelimit/ResultWriteWorkerTests.java
@@ -34,13 +34,11 @@
 import org.opensearch.OpenSearchStatusException;
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.index.IndexRequest;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.ad.transport.ADResultBulkRequest;
-import org.opensearch.ad.transport.ADResultBulkResponse;
-import org.opensearch.ad.transport.handler.MultiEntityResultHandler;
+import org.opensearch.ad.transport.handler.ADIndexMemoryPressureAwareResultHandler;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
@@ -49,12 +47,17 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
+import org.opensearch.timeseries.ratelimit.ResultWriteRequest;
+import org.opensearch.timeseries.settings.TimeSeriesSettings;
+import org.opensearch.timeseries.transport.ResultBulkResponse;
 import org.opensearch.timeseries.util.RestHandlerUtils;

 public class ResultWriteWorkerTests extends AbstractRateLimitingTest {
-    ResultWriteWorker resultWriteQueue;
+    ADResultWriteWorker resultWriteQueue;
     ClusterService clusterService;
-    MultiEntityResultHandler resultHandler;
+    ADIndexMemoryPressureAwareResultHandler resultHandler;
     AnomalyResult detectResult;

     @Override
@@ -69,7 +72,7 @@ public void setUp() throws Exception {
                 new HashSet<>(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
                             AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY,
                             AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE
                         )
@@ -81,27 +84,27 @@ public void setUp() throws Exception {
         );
         threadPool = mock(ThreadPool.class);
         setUpADThreadPool(threadPool);

-        resultHandler = mock(MultiEntityResultHandler.class);
+        resultHandler = mock(ADIndexMemoryPressureAwareResultHandler.class);

-        resultWriteQueue = new ResultWriteWorker(
+        resultWriteQueue = new ADResultWriteWorker(
             Integer.MAX_VALUE,
-            AnomalyDetectorSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES,
-            AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+            TimeSeriesSettings.RESULT_WRITE_QUEUE_SIZE_IN_BYTES,
+            AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
             clusterService,
             new Random(42),
-            mock(ADCircuitBreakerService.class),
+            mock(CircuitBreakerService.class),
             threadPool,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_QUEUED_TASKS_RATIO,
+            TimeSeriesSettings.MAX_QUEUED_TASKS_RATIO,
             clock,
-            AnomalyDetectorSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.LOW_SEGMENT_PRUNE_RATIO,
-            AnomalyDetectorSettings.MAINTENANCE_FREQ_CONSTANT,
-            AnomalyDetectorSettings.QUEUE_MAINTENANCE,
+            TimeSeriesSettings.MEDIUM_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.LOW_SEGMENT_PRUNE_RATIO,
+            TimeSeriesSettings.MAINTENANCE_FREQ_CONSTANT,
+            TimeSeriesSettings.QUEUE_MAINTENANCE,
             resultHandler,
             xContentRegistry(),
             nodeStateManager,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE
+            TimeSeriesSettings.HOURLY_MAINTENANCE
         );

         detectResult = TestHelpers.randomHCADAnomalyDetectResult(0.8, Double.NaN, null);
@@ -110,7 +113,7 @@ public void setUp() throws Exception {
     public void testRegular() {
         List retryRequests = new ArrayList<>();

-        ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests);
+        ResultBulkResponse resp = new ResultBulkResponse(retryRequests);

         ADResultBulkRequest request = new ADResultBulkRequest();
         ResultWriteRequest resultWriteRequest = new ResultWriteRequest(
@@ -123,7 +126,7 @@ public void testRegular() {
         request.add(resultWriteRequest);

         doAnswer(invocation -> {
-            ActionListener listener = invocation.getArgument(1);
+            ActionListener listener = invocation.getArgument(1);
             listener.onResponse(resp);
             return null;
         }).when(resultHandler).flush(any(), any());
@@ -142,7 +145,7 @@ public void testSingleRetryRequest() throws IOException {
             retryRequests.add(indexRequest);
         }

-        ADResultBulkResponse resp = new ADResultBulkResponse(retryRequests);
+        ResultBulkResponse resp = new ResultBulkResponse(retryRequests);

         ADResultBulkRequest request = new ADResultBulkRequest();
         ResultWriteRequest resultWriteRequest = new ResultWriteRequest(
@@ -156,9 +159,9 @@ public void testSingleRetryRequest() throws IOException {
         final AtomicBoolean retried = new AtomicBoolean();
         doAnswer(invocation -> {
-            ActionListener listener = invocation.getArgument(1);
+            ActionListener listener = invocation.getArgument(1);
             if (retried.get()) {
-                listener.onResponse(new ADResultBulkResponse());
+                listener.onResponse(new ResultBulkResponse());
             } else {
                 retried.set(true);
                 listener.onResponse(resp);
@@ -175,9 +178,9 @@ public void testRetryException() {
         final AtomicBoolean retried = new AtomicBoolean();
         doAnswer(invocation -> {
-            ActionListener listener = invocation.getArgument(1);
+            ActionListener listener = invocation.getArgument(1);
             if (retried.get()) {
-                listener.onResponse(new ADResultBulkResponse());
+                listener.onResponse(new ResultBulkResponse());
             } else {
                 retried.set(true);
                 listener.onFailure(new OpenSearchStatusException("blah", RestStatus.REQUEST_TIMEOUT));
@@ -194,7 +197,7 @@ public void testOverloaded() {
         doAnswer(invocation -> {
-            ActionListener listener = invocation.getArgument(1);
+            ActionListener listener = invocation.getArgument(1);
             listener.onFailure(new OpenSearchRejectedExecutionException("blah", true));
             return null;
diff --git a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java
index fb1ccc1e4..43bc955e1 100644
--- a/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java
+++ b/src/test/java/org/opensearch/ad/rest/ADRestTestUtils.java
@@ -40,13 +40,14 @@
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskProfile;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.client.Response;
 import org.opensearch.client.RestClient;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.model.DateRange;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.model.TimeSeriesTask;

 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -260,12 +261,12 @@ public static List searchLatestAdTaskOfDetector(RestClient client, Strin
         for (Object adTaskResponse : adTaskResponses) {
             String id = (String) ((Map) adTaskResponse).get("_id");
             Map source = (Map) ((Map) adTaskResponse).get("_source");
-            String state = (String) source.get(ADTask.STATE_FIELD);
+            String state = (String) source.get(TimeSeriesTask.STATE_FIELD);
             String parsedDetectorId = (String) source.get(ADTask.DETECTOR_ID_FIELD);
-            Double taskProgress = (Double) source.get(ADTask.TASK_PROGRESS_FIELD);
-            Double initProgress = (Double) source.get(ADTask.INIT_PROGRESS_FIELD);
-            String parsedTaskType = (String) source.get(ADTask.TASK_TYPE_FIELD);
-            String coordinatingNode = (String) source.get(ADTask.COORDINATING_NODE_FIELD);
+            Double taskProgress = (Double) source.get(TimeSeriesTask.TASK_PROGRESS_FIELD);
+            Double initProgress = (Double) source.get(TimeSeriesTask.INIT_PROGRESS_FIELD);
+            String parsedTaskType = (String) source.get(TimeSeriesTask.TASK_TYPE_FIELD);
+            String coordinatingNode = (String) source.get(TimeSeriesTask.COORDINATING_NODE_FIELD);
             ADTask adTask = ADTask
                 .builder()
                 .taskId(id)
@@ -351,12 +352,12 @@ public static Map getDetectorWithJobAndTask(RestClient client, S
         Map jobMap = (Map) responseMap.get(ANOMALY_DETECTOR_JOB);
         if (jobMap != null) {
-            String jobName = (String) jobMap.get(AnomalyDetectorJob.NAME_FIELD);
-            boolean enabled = (boolean) jobMap.get(AnomalyDetectorJob.IS_ENABLED_FIELD);
-            long enabledTime = (long) jobMap.get(AnomalyDetectorJob.ENABLED_TIME_FIELD);
-            long lastUpdateTime = (long) jobMap.get(AnomalyDetectorJob.LAST_UPDATE_TIME_FIELD);
+            String jobName = (String) jobMap.get(Job.NAME_FIELD);
+            boolean enabled = (boolean) jobMap.get(Job.IS_ENABLED_FIELD);
+            long enabledTime = (long) jobMap.get(Job.ENABLED_TIME_FIELD);
+            long lastUpdateTime = (long) jobMap.get(Job.LAST_UPDATE_TIME_FIELD);

-            AnomalyDetectorJob job = new AnomalyDetectorJob(
+            Job job = new Job(
                 jobName,
                 null,
                 null,
@@ -387,13 +388,13 @@ public static Map getDetectorWithJobAndTask(RestClient client, S
     }

     private static ADTask parseAdTask(Map taskMap) {
-        String id = (String) taskMap.get(ADTask.TASK_ID_FIELD);
-        String state = (String) taskMap.get(ADTask.STATE_FIELD);
+        String id = (String) taskMap.get(TimeSeriesTask.TASK_ID_FIELD);
+        String state = (String) taskMap.get(TimeSeriesTask.STATE_FIELD);
         String parsedDetectorId = (String) taskMap.get(ADTask.DETECTOR_ID_FIELD);
-        Double taskProgress = (Double) taskMap.get(ADTask.TASK_PROGRESS_FIELD);
-        Double initProgress = (Double) taskMap.get(ADTask.INIT_PROGRESS_FIELD);
-        String parsedTaskType = (String) taskMap.get(ADTask.TASK_TYPE_FIELD);
-        String coordinatingNode = (String) taskMap.get(ADTask.COORDINATING_NODE_FIELD);
+        Double taskProgress = (Double) taskMap.get(TimeSeriesTask.TASK_PROGRESS_FIELD);
+        Double initProgress = (Double) taskMap.get(TimeSeriesTask.INIT_PROGRESS_FIELD);
+        String parsedTaskType = (String) taskMap.get(TimeSeriesTask.TASK_TYPE_FIELD);
+        String coordinatingNode = (String) taskMap.get(TimeSeriesTask.COORDINATING_NODE_FIELD);
         return ADTask
             .builder()
             .taskId(id)
diff --git a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
index 390e68ef7..9baecdbed 100644
--- a/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
+++ b/src/test/java/org/opensearch/ad/rest/AnomalyDetectorRestApiIT.java
@@ -14,7 +14,6 @@
 import static org.hamcrest.Matchers.containsString;
 import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.DUPLICATE_DETECTOR_MSG;
 import static org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler.NO_DOCS_IN_USER_INDEX_MSG;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;

 import java.io.IOException;
 import java.time.Instant;
@@ -36,7 +35,6 @@
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.AnomalyDetectorExecutionInput;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.rest.handler.AbstractAnomalyDetectorActionHandler;
 import org.opensearch.ad.settings.ADEnabledSetting;
@@ -54,6 +52,7 @@
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.DateRange;
 import org.opensearch.timeseries.model.Feature;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.timeseries.settings.TimeSeriesSettings;

 import com.google.common.collect.ImmutableList;
@@ -510,7 +509,7 @@ public void testPreviewAnomalyDetector() throws Exception {
             .makeRequest(
                 client(),
                 "POST",
-                String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getDetectorId()),
+                String.format(Locale.ROOT, TestHelpers.AD_BASE_PREVIEW_URI, input.getConfigId()),
                 ImmutableMap.of(),
                 TestHelpers.toHttpEntity(input),
                 null
@@ -820,8 +819,8 @@ public void testGetDetectorWithAdJob() throws Exception {
         ToXContentObject[] results = getAnomalyDetector(detector.getId(), true, client());
         assertEquals("Incorrect Location header", detector, results[0]);
-        assertEquals("Incorrect detector job name", detector.getId(), ((AnomalyDetectorJob) results[1]).getName());
-        assertTrue(((AnomalyDetectorJob) results[1]).isEnabled());
+        assertEquals("Incorrect detector job name", detector.getId(), ((Job) results[1]).getName());
+        assertTrue(((Job) results[1]).isEnabled());

         results = getAnomalyDetector(detector.getId(), false, client());
         assertEquals("Incorrect Location header", detector, results[0]);
@@ -895,7 +894,7 @@ public void testStartAdJobWithNonexistingDetector() throws Exception {
         TestHelpers
             .assertFailWith(
                 ResponseException.class,
-                FAIL_TO_FIND_CONFIG_MSG,
+                CommonMessages.FAIL_TO_FIND_CONFIG_MSG,
                 () -> TestHelpers
                     .makeRequest(
                         client(),
@@ -997,7 +996,7 @@ public void testStopNonExistingAdJob() throws Exception {
         TestHelpers
             .assertFailWith(
                 ResponseException.class,
-                FAIL_TO_FIND_CONFIG_MSG,
+                CommonMessages.FAIL_TO_FIND_CONFIG_MSG,
                 () -> TestHelpers
                     .makeRequest(
                         client(),
diff --git a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
index 7d0be2ae9..8e9f1850d 100644
--- a/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
+++ b/src/test/java/org/opensearch/ad/rest/HistoricalAnalysisRestApiIT.java
@@ -17,8 +17,8 @@
 import static org.opensearch.timeseries.TestHelpers.AD_BASE_STATS_URI;
 import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS;
 import static org.opensearch.timeseries.stats.StatNames.AD_TOTAL_BATCH_TASK_EXECUTION_COUNT;
-import static org.opensearch.timeseries.stats.StatNames.MULTI_ENTITY_DETECTOR_COUNT;
-import static org.opensearch.timeseries.stats.StatNames.SINGLE_ENTITY_DETECTOR_COUNT;
+import static org.opensearch.timeseries.stats.StatNames.HC_DETECTOR_COUNT;
+import static org.opensearch.timeseries.stats.StatNames.SINGLE_STREAM_DETECTOR_COUNT;

 import java.io.IOException;
 import java.util.List;
@@ -34,14 +34,14 @@
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskProfile;
-import org.opensearch.ad.model.ADTaskState;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.client.Response;
 import org.opensearch.client.ResponseException;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.model.Job;
+import org.opensearch.timeseries.model.TaskState;

 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -118,8 +118,8 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul
         // get task profile
         ADTaskProfile adTaskProfile = waitUntilGetTaskProfile(detectorId);
         if (categoryFieldSize > 0) {
-            if (!ADTaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) {
-                adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(ADTaskState.RUNNING.name())).get(0);
+            if (!TaskState.RUNNING.name().equals(adTaskProfile.getAdTask().getState())) {
+                adTaskProfile = (ADTaskProfile) waitUntilTaskReachState(detectorId, ImmutableSet.of(TaskState.RUNNING.name())).get(0);
             }
             assertEquals((int) Math.pow(categoryFieldDocCount, categoryFieldSize), adTaskProfile.getTotalEntitiesCount().intValue());
             assertTrue(adTaskProfile.getPendingEntitiesCount() > 0);
@@ -133,7 +133,7 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul
         Response statsResponse = TestHelpers.makeRequest(client(), "GET", AD_BASE_STATS_URI, ImmutableMap.of(), "", null);
         String statsResult = EntityUtils.toString(statsResponse.getEntity());
         Map stringObjectMap = TestHelpers.parseStatsResult(statsResult);
-        String detectorCountState = categoryFieldSize > 0 ? MULTI_ENTITY_DETECTOR_COUNT.getName() : SINGLE_ENTITY_DETECTOR_COUNT.getName();
+        String detectorCountState = categoryFieldSize > 0 ? HC_DETECTOR_COUNT.getName() : SINGLE_STREAM_DETECTOR_COUNT.getName();
         assertTrue((long) stringObjectMap.get(detectorCountState) > 0);
         Map nodes = (Map) stringObjectMap.get("nodes");
         long totalBatchTaskExecution = 0;
@@ -146,7 +146,7 @@ private List startHistoricalAnalysis(int categoryFieldSize, String resul
         // get detector with AD task
         ToXContentObject[] result = getHistoricalAnomalyDetector(detectorId, true, client());
         AnomalyDetector parsedDetector = (AnomalyDetector) result[0];
-        AnomalyDetectorJob parsedJob = (AnomalyDetectorJob) result[1];
+        Job parsedJob = (Job) result[1];
         ADTask parsedADTask = (ADTask) result[2];
         assertNull(parsedJob);
         assertNotNull(parsedDetector);
@@ -172,7 +172,7 @@ public void testStopHistoricalAnalysis() throws Exception {
         assertEquals(RestStatus.OK, TestHelpers.restStatus(stopDetectorResponse));

         // get task profile
-        checkIfTaskCanFinishCorrectly(detectorId, taskId, ImmutableSet.of(ADTaskState.STOPPED.name()));
+        checkIfTaskCanFinishCorrectly(detectorId, taskId, ImmutableSet.of(TaskState.STOPPED.name()));
         updateClusterSettings(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1);
         waitUntilTaskDone(detectorId);
diff --git a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java
index 1c0758ebf..7b94bc7e0 100644
--- a/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java
+++ b/src/test/java/org/opensearch/ad/rest/SecureADRestIT.java
@@ -225,7 +225,7 @@ public void testGetApiFilterByEnabledForAdmin() throws IOException {
         AnomalyDetector aliceDetector = createRandomAnomalyDetector(false, false, aliceClient);
         enableFilterBy();
         confirmingClientIsAdmin();
-        AnomalyDetector detector = getAnomalyDetector(aliceDetector.getId(), client());
+        AnomalyDetector detector = getAnomalyDetector(aliceDetector.getConfigId(), client());
         Assert
             .assertArrayEquals(
                 "User backend role of detector doesn't change",
@@ -240,7 +240,7 @@ public void testUpdateApiFilterByEnabledForAdmin() throws IOException {
         enableFilterBy();

         AnomalyDetector newDetector = new AnomalyDetector(
-            aliceDetector.getId(),
+            aliceDetector.getConfigId(),
             aliceDetector.getVersion(),
             aliceDetector.getName(),
             randomAlphaOfLength(10),
diff --git a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java
index 59eba777c..4e77ece51 100644
--- a/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/rest/handler/IndexAnomalyDetectorJobActionHandlerTests.java
@@ -21,7 +21,6 @@
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.opensearch.action.DocWriteResponse.Result.CREATED;
-import static org.opensearch.ad.constant.ADCommonMessages.CAN_NOT_FIND_LATEST_TASK;

 import java.io.IOException;
 import java.util.Arrays;
@@ -34,8 +33,8 @@
 import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
 import org.opensearch.action.update.UpdateResponse;
+import org.opensearch.ad.ADNodeStateManager;
 import org.opensearch.ad.ExecuteADResultResponseRecorder;
-import org.opensearch.ad.NodeStateManager;
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.indices.ADIndexManagement;
@@ -44,12 +43,10 @@
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.task.ADTaskCacheManager;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.transport.AnomalyDetectorJobResponse;
 import org.opensearch.ad.transport.AnomalyResultAction;
 import org.opensearch.ad.transport.AnomalyResultResponse;
 import org.opensearch.ad.transport.ProfileAction;
 import org.opensearch.ad.transport.ProfileResponse;
-import org.opensearch.ad.transport.handler.AnomalyIndexHandler;
 import org.opensearch.client.Client;
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
@@ -59,7 +56,11 @@
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.InternalFailure;
 import org.opensearch.timeseries.common.exception.ResourceNotFoundException;
+import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.model.Feature;
+import org.opensearch.timeseries.rest.handler.IndexJobActionHandler;
+import org.opensearch.timeseries.transport.JobResponse;
+import org.opensearch.timeseries.transport.handler.ResultIndexingHandler;
 import org.opensearch.timeseries.util.DiscoveryNodeFilterer;
 import org.opensearch.transport.TransportService;

@@ -84,9 +85,9 @@ public class IndexAnomalyDetectorJobActionHandlerTests extends OpenSearchTestCas
     private ExecuteADResultResponseRecorder recorder;
     private Client client;
-    private IndexAnomalyDetectorJobActionHandler handler;
-    private AnomalyIndexHandler anomalyResultHandler;
-    private NodeStateManager nodeStateManager;
+    private IndexJobActionHandler handler;
+    private ResultIndexingHandler anomalyResultHandler;
+    private ADNodeStateManager nodeStateManager;
     private ADTaskCacheManager adTaskCacheManager;

     @BeforeClass
@@ -146,9 +147,9 @@ public void setUp() throws Exception {
         adTaskManager = mock(ADTaskManager.class);
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
-            ActionListener listener = (ActionListener) args[4];
+            ActionListener listener = (ActionListener) args[4];

-            AnomalyDetectorJobResponse response = mock(AnomalyDetectorJobResponse.class);
+            JobResponse response = mock(JobResponse.class);
             listener.onResponse(response);

             return null;
@@ -156,9 +157,9 @@ public void setUp() throws Exception {

         threadPool = mock(ThreadPool.class);

-        anomalyResultHandler = mock(AnomalyIndexHandler.class);
+        anomalyResultHandler = mock(ResultIndexingHandler.class);

-        nodeStateManager = mock(NodeStateManager.class);
+        nodeStateManager = mock(ADNodeStateManager.class);

         adTaskCacheManager = mock(ADTaskCacheManager.class);
         when(adTaskCacheManager.hasQueriedResultIndex(anyString())).thenReturn(true);
@@ -175,7 +176,7 @@ public void setUp() throws Exception {
             32
         );

-        handler = new IndexAnomalyDetectorJobActionHandler(
+        handler = new IndexJobActionHandler(
             client,
             anomalyDetectionIndices,
             detectorId,
@@ -193,9 +194,9 @@ public void setUp() throws Exception {
     public void testDelayHCProfile() {
         when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(false);

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);

-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);

         verify(client, times(1)).get(any(), any());
         verify(client, times(1)).execute(any(), any(), any());
@@ -220,9 +221,9 @@ public void testNoDelayHCProfile() {

         when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);

-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);

         verify(client, times(1)).get(any(), any());
         verify(client, times(2)).execute(any(), any(), any());
@@ -246,9 +247,9 @@ public void testHCProfileException() {

         when(adTaskManager.isHCRealtimeTaskStartInitializing(anyString())).thenReturn(true);

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);

-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);

         verify(client, times(1)).get(any(), any());
         verify(client, times(2)).execute(any(), any(), any());
@@ -278,14 +279,14 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNodeResourceNotFoundExcept
             Object[] args = invocation.getArguments();
             ActionListener listener = (ActionListener) args[5];

-            listener.onFailure(new ResourceNotFoundException(CAN_NOT_FIND_LATEST_TASK));
+            listener.onFailure(new ResourceNotFoundException(CommonMessages.CAN_NOT_FIND_LATEST_TASK));

             return null;
         }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);

-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);

         verify(client, times(1)).get(any(), any());
         verify(client, times(2)).execute(any(), any(), any());
@@ -321,9 +322,9 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() {
             return null;
         }).when(adTaskManager).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);

-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);

         verify(client, times(1)).get(any(), any());
         verify(client, times(2)).execute(any(), any(), any());
@@ -331,7 +332,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingException() {
         verify(adTaskManager, times(1)).isHCRealtimeTaskStartInitializing(anyString());
         verify(adTaskManager, times(1)).updateLatestRealtimeTaskOnCoordinatingNode(any(), any(), any(), any(), any(), any());
         verify(adTaskManager, never()).removeRealtimeTaskCache(anyString());
-        verify(adTaskManager, times(1)).skipUpdateHCRealtimeTask(anyString(), anyString());
+        verify(adTaskManager, times(1)).skipUpdateRealtimeTask(anyString(), anyString());
         verify(threadPool, never()).schedule(any(), any(), any());
         verify(listener, times(1)).onResponse(any());
     }
@@ -347,7 +348,7 @@ public void testIndexException() throws IOException {
             return null;
         }).when(client).execute(any(AnomalyResultAction.class), any(), any());

-        ActionListener listener = mock(ActionListener.class);
+        ActionListener listener = mock(ActionListener.class);
         AggregationBuilder aggregationBuilder = TestHelpers
             .parseAggregation("{\"test\":{\"max\":{\"field\":\"" + MockSimpleLog.VALUE_FIELD + "\"}}}");
         Feature feature = new Feature(randomAlphaOfLength(5), randomAlphaOfLength(10), true, aggregationBuilder);
@@ -361,7 +362,7 @@ public void testIndexException() throws IOException {
             ADCommonName.CUSTOM_RESULT_INDEX_PREFIX + "index"
         );
         when(anomalyDetectionIndices.doesIndexExist(anyString())).thenReturn(false);
-        handler.startAnomalyDetectorJob(detector, listener);
+        handler.startJob(detector, listener);
         verify(anomalyResultHandler, times(1)).index(any(), any(), eq(null));
         verify(threadPool, times(1)).schedule(any(), any(), any());
     }
diff --git a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
index 72e336ea7..ba9c2b4fc 100644
--- a/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
+++ b/src/test/java/org/opensearch/ad/settings/AnomalyDetectorSettingsTests.java
@@ -76,47 +76,47 @@ public void testAllOpenSearchSettingsReturned() {
                 .containsAll(
                     Arrays
                         .asList(
-                            AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS,
-                            AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS,
+                            AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS,
+                            AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS,
                             AnomalyDetectorSettings.MAX_ANOMALY_FEATURES,
-                            AnomalyDetectorSettings.REQUEST_TIMEOUT,
+                            AnomalyDetectorSettings.AD_REQUEST_TIMEOUT,
                             AnomalyDetectorSettings.DETECTION_INTERVAL,
                             AnomalyDetectorSettings.DETECTION_WINDOW_DELAY,
                             AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD,
                             AnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS_PER_SHARD,
-                            AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE,
-                            AnomalyDetectorSettings.COOLDOWN_MINUTES,
-                            AnomalyDetectorSettings.BACKOFF_MINUTES,
+                            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE,
+                            AnomalyDetectorSettings.AD_COOLDOWN_MINUTES,
+                            AnomalyDetectorSettings.AD_BACKOFF_MINUTES,
                             AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY,
                             AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF,
                             AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD,
-                            AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE,
-                            AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY,
+                            AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE,
+                            AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY,
                             AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW,
                             AnomalyDetectorSettings.AD_INDEX_PRESSURE_SOFT_LIMIT,
                             AnomalyDetectorSettings.AD_INDEX_PRESSURE_HARD_LIMIT,
                             AnomalyDetectorSettings.AD_MAX_PRIMARY_SHARDS,
-                            AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES,
+                            AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES,
                             AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE,
                             AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS,
                             AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR,
                             AnomalyDetectorSettings.BATCH_TASK_PIECE_SIZE,
                             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_CONCURRENCY,
                             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_CONCURRENCY,
-                            AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_CONCURRENCY,
+                            AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_CONCURRENCY,
                             AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_CONCURRENCY,
                             AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_BATCH_SIZE,
                             AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_BATCH_SIZE,
                             AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_BATCH_SIZE,
-                            AnomalyDetectorSettings.DEDICATED_CACHE_SIZE,
-                            AnomalyDetectorSettings.COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
-                            AnomalyDetectorSettings.EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
-                            AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY,
-                            AnomalyDetectorSettings.PAGE_SIZE
+                            AnomalyDetectorSettings.AD_DEDICATED_CACHE_SIZE,
+                            AnomalyDetectorSettings.AD_COLD_ENTITY_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_CHECKPOINT_READ_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_CHECKPOINT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_RESULT_WRITE_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_ENTITY_COLD_START_QUEUE_MAX_HEAP_PERCENT,
+                            AnomalyDetectorSettings.AD_EXPECTED_COLD_ENTITY_EXECUTION_TIME_IN_MILLISECS,
+                            AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY,
+                            AnomalyDetectorSettings.AD_PAGE_SIZE
                         )
                 )
         );
@@ -124,11 +124,11 @@ public void testAllOpenSearchSettingsReturned() {

     public void testAllLegacyOpenDistroSettingsFallback() {
         assertEquals(
-            AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(Settings.EMPTY)
         );
         assertEquals(
@@ -136,7 +136,7 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.REQUEST_TIMEOUT.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(Settings.EMPTY)
         );
         assertEquals(
@@ -152,15 +152,15 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_COOLDOWN_MINUTES.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_BACKOFF_MINUTES.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(Settings.EMPTY)
         );
         assertEquals(
@@ -176,7 +176,7 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(Settings.EMPTY)
         );
         // MAX_ENTITIES_FOR_PREVIEW does not use legacy setting
@@ -188,7 +188,7 @@ public void testAllLegacyOpenDistroSettingsFallback() {
             LegacyOpenDistroAnomalyDetectorSettings.MAX_PRIMARY_SHARDS.get(Settings.EMPTY)
         );
         assertEquals(
-            AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY),
+            AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY),
             LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(Settings.EMPTY)
         );
         assertEquals(
@@ -211,15 +211,15 @@ public void testAllLegacyOpenDistroSettingsFallback() {

     public void testSettingsGetValue() {
         Settings settings = Settings.builder().put("plugins.anomaly_detection.request_timeout", "42s").build();
-        assertEquals(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(42));
+        assertEquals(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(42));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(10));

         settings = Settings.builder().put("plugins.anomaly_detection.max_anomaly_detectors", 99).build();
-        assertEquals(AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(99));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(99));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(1000));

         settings = Settings.builder().put("plugins.anomaly_detection.max_multi_entity_anomaly_detectors", 98).build();
-        assertEquals(AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(98));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS.get(settings), Integer.valueOf(98));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(10));

         settings = Settings.builder().put("plugins.anomaly_detection.max_anomaly_features", 7).build();
@@ -253,15 +253,15 @@ public void testSettingsGetValue() {
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(30));

         settings = Settings.builder().put("plugins.anomaly_detection.max_retry_for_unresponsive_node", 91).build();
-        assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(91));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(91));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(5));

         settings = Settings.builder().put("plugins.anomaly_detection.cooldown_minutes", TimeValue.timeValueMinutes(90)).build();
-        assertEquals(AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(90));
+        assertEquals(AnomalyDetectorSettings.AD_COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(90));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(5));

         settings = Settings.builder().put("plugins.anomaly_detection.backoff_minutes", TimeValue.timeValueMinutes(89)).build();
-        assertEquals(AnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(89));
+        assertEquals(AnomalyDetectorSettings.AD_BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(89));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(15));

         settings = Settings.builder().put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(88)).build();
@@ -273,19 +273,19 @@ public void testSettingsGetValue() {
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_BACKOFF.get(settings), Integer.valueOf(3));

         settings = Settings.builder().put("plugins.anomaly_detection.max_retry_for_end_run_exception", 86).build();
-        assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(86));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(86));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(6));

         settings = Settings.builder().put("plugins.anomaly_detection.filter_by_backend_roles", true).build();
-        assertEquals(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true));
+        assertEquals(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(false));

         settings = Settings.builder().put("plugins.anomaly_detection.model_max_size_percent", 0.3).build();
-        assertEquals(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.3));
+        assertEquals(AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.3));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.1));

         settings = Settings.builder().put("plugins.anomaly_detection.max_entities_per_query", 83).build();
-        assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(83));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(83));
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY.get(settings), Integer.valueOf(1000));

         settings = Settings.builder().put("plugins.anomaly_detection.max_entities_for_preview", 22).build();
@@ -350,24 +350,24 @@ public void testSettingsGetValueWithLegacyFallback() {
             .put("opendistro.anomaly_detection.batch_task_piece_interval_seconds", 26)
             .build();

-        assertEquals(AnomalyDetectorSettings.MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(1));
-        assertEquals(AnomalyDetectorSettings.MAX_MULTI_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(2));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_SINGLE_ENTITY_ANOMALY_DETECTORS.get(settings), Integer.valueOf(1));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_HC_ANOMALY_DETECTORS.get(settings), Integer.valueOf(2));
         assertEquals(AnomalyDetectorSettings.MAX_ANOMALY_FEATURES.get(settings), Integer.valueOf(3));
-        assertEquals(AnomalyDetectorSettings.REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(4));
+        assertEquals(AnomalyDetectorSettings.AD_REQUEST_TIMEOUT.get(settings), TimeValue.timeValueSeconds(4));
         assertEquals(AnomalyDetectorSettings.DETECTION_INTERVAL.get(settings), TimeValue.timeValueMinutes(5));
         assertEquals(AnomalyDetectorSettings.DETECTION_WINDOW_DELAY.get(settings), TimeValue.timeValueMinutes(6));
         assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_ROLLOVER_PERIOD.get(settings), TimeValue.timeValueHours(7));
         // AD_RESULT_HISTORY_MAX_DOCS is removed in the new release
         assertEquals(LegacyOpenDistroAnomalyDetectorSettings.AD_RESULT_HISTORY_MAX_DOCS.get(settings), Long.valueOf(8L));
         assertEquals(AnomalyDetectorSettings.AD_RESULT_HISTORY_RETENTION_PERIOD.get(settings), TimeValue.timeValueDays(9));
-        assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(10));
-        assertEquals(AnomalyDetectorSettings.COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(11));
-        assertEquals(AnomalyDetectorSettings.BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(12));
+        assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE.get(settings), Integer.valueOf(10));
+        assertEquals(AnomalyDetectorSettings.AD_COOLDOWN_MINUTES.get(settings), TimeValue.timeValueMinutes(11));
+        assertEquals(AnomalyDetectorSettings.AD_BACKOFF_MINUTES.get(settings), TimeValue.timeValueMinutes(12));
assertEquals(AnomalyDetectorSettings.AD_BACKOFF_INITIAL_DELAY.get(settings), TimeValue.timeValueMillis(13)); assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_BACKOFF.get(settings), Integer.valueOf(14)); - assertEquals(AnomalyDetectorSettings.MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(15)); - assertEquals(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true)); - assertEquals(AnomalyDetectorSettings.MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.6D)); + assertEquals(AnomalyDetectorSettings.AD_MAX_RETRY_FOR_END_RUN_EXCEPTION.get(settings), Integer.valueOf(15)); + assertEquals(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.get(settings), Boolean.valueOf(true)); + assertEquals(AnomalyDetectorSettings.AD_MODEL_MAX_SIZE_PERCENTAGE.get(settings), Double.valueOf(0.6D)); // MAX_ENTITIES_FOR_PREVIEW uses default instead of legacy fallback assertEquals(AnomalyDetectorSettings.MAX_ENTITIES_FOR_PREVIEW.get(settings), Integer.valueOf(5)); // INDEX_PRESSURE_SOFT_LIMIT uses default instead of legacy fallback diff --git a/src/test/java/org/opensearch/ad/stats/ADStatTests.java b/src/test/java/org/opensearch/ad/stats/ADStatTests.java index 1912f92ad..7ec161f1b 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatTests.java @@ -14,32 +14,33 @@ import java.util.function.Supplier; import org.junit.Test; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.SettableSupplier; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.TimeSeriesStat; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; public class ADStatTests extends OpenSearchTestCase { @Test public void testIsClusterLevel() { - ADStat stat1 = new ADStat<>(true, new TestSupplier()); + TimeSeriesStat stat1 = new TimeSeriesStat<>(true, new TestSupplier()); assertTrue("isCluster returns the wrong value", stat1.isClusterLevel()); - ADStat stat2 = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat stat2 = new TimeSeriesStat<>(false, new TestSupplier()); assertTrue("isCluster returns the wrong value", !stat2.isClusterLevel()); } @Test public void testGetValue() { - ADStat stat1 = new ADStat<>(false, new CounterSupplier()); + TimeSeriesStat stat1 = new TimeSeriesStat<>(false, new CounterSupplier()); assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue())); - ADStat stat2 = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat stat2 = new TimeSeriesStat<>(false, new TestSupplier()); assertEquals("GetValue returns the incorrect value", "test", stat2.getValue()); } @Test public void testSetValue() { - ADStat stat = new ADStat<>(false, new SettableSupplier()); + TimeSeriesStat stat = new TimeSeriesStat<>(false, new SettableSupplier()); assertEquals("GetValue returns the incorrect value", 0L, (long) (stat.getValue())); stat.setValue(10L); assertEquals("GetValue returns the incorrect value", 10L, (long) stat.getValue()); @@ -47,7 +48,7 @@ public void testSetValue() { @Test public void testIncrement() { - ADStat incrementStat = new ADStat<>(false, new CounterSupplier()); + TimeSeriesStat incrementStat = new TimeSeriesStat<>(false, new CounterSupplier()); for (Long i = 0L; i < 100; i++) { assertEquals("increment does not work", i, incrementStat.getValue()); @@ -55,7 +56,7 @@ public void testIncrement() { } // Ensure that no problems 
occur for a stat that cannot be incremented - ADStat nonIncStat = new ADStat<>(false, new TestSupplier()); + TimeSeriesStat nonIncStat = new TimeSeriesStat<>(false, new TestSupplier()); nonIncStat.increment(); } diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java index 194623bd5..a4ed70e68 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsResponseTests.java @@ -21,6 +21,7 @@ import org.opensearch.action.FailedNodeException; import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesResponse; +import org.opensearch.ad.transport.ADStatsResponse; import org.opensearch.cluster.ClusterName; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; diff --git a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java index 0d8150683..ac69bb638 100644 --- a/src/test/java/org/opensearch/ad/stats/ADStatsTests.java +++ b/src/test/java/org/opensearch/ad/stats/ADStatsTests.java @@ -14,7 +14,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; import java.time.Clock; import java.util.ArrayList; @@ -30,21 +30,17 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ADModelState; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; -import org.opensearch.ad.stats.suppliers.CounterSupplier; -import org.opensearch.ad.stats.suppliers.IndexStatusSupplier; -import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeSupplier; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -53,7 +49,7 @@ public class ADStatsTests extends OpenSearchTestCase { - private Map> statsMap; + private Map> statsMap; private ADStats adStats; private RandomCutForest rcf; private HybridThresholdingModel thresholdingModel; @@ -64,10 +60,10 @@ public class ADStatsTests extends OpenSearchTestCase { private Clock clock; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private EntityCacheProvider cacheProvider; @Before public void setup() { @@ -77,25 +73,34 @@ public void setup() { rcf = RandomCutForest.builder().dimensions(1).sampleSize(2).numberOfTrees(1).build(); thresholdingModel = new 
HybridThresholdingModel(1e-8, 1e-5, 200, 10_000, 2, 5_000_000); - List> modelsInformation = new ArrayList<>( + List> modelsInformation = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ADModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ADModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ADModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), + new ADModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f + ) ) ); when(modelManager.getAllModels()).thenReturn(modelsInformation); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); + List> entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); EntityCache cache = mock(EntityCache.class); when(cacheProvider.get()).thenReturn(cache); - when(cache.getAllModels()).thenReturn(entityModelsInformation); + when(cache.getAllModelStates()).thenReturn(entityModelsInformation); IndexUtils indexUtils = mock(IndexUtils.class); @@ -108,20 +113,23 @@ public void setup() { nodeStatName1 = "nodeStat1"; nodeStatName2 = "nodeStat2"; - Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); + Settings settings = Settings.builder().put(AD_MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); ClusterService clusterService = mock(ClusterService.class); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AD_MAX_MODEL_SIZE_PER_NODE))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - statsMap = new HashMap>() { + statsMap = new HashMap>() { { - put(nodeStatName1, new ADStat<>(false, new CounterSupplier())); - put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))); - put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); - put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); + put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier())); + put( + nodeStatName2, + new TimeSeriesStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)) + ); + put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1"))); + put(clusterStatName2, new 
TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2"))); } }; @@ -135,11 +143,11 @@ public void testStatNamesGetNames() { @Test public void testGetStats() { - Map> stats = adStats.getStats(); + Map> stats = adStats.getStats(); assertEquals("getStats returns the incorrect number of stats", stats.size(), statsMap.size()); - for (Map.Entry> stat : stats.entrySet()) { + for (Map.Entry> stat : stats.entrySet()) { assertTrue( "getStats returns incorrect stats", adStats.getStats().containsKey(stat.getKey()) && adStats.getStats().get(stat.getKey()) == stat.getValue() @@ -149,7 +157,7 @@ public void testGetStats() { @Test public void testGetStat() { - ADStat stat = adStats.getStat(clusterStatName1); + TimeSeriesStat stat = adStats.getStat(clusterStatName1); assertTrue( "getStat returns incorrect stat", @@ -159,10 +167,10 @@ public void testGetStat() { @Test public void testGetNodeStats() { - Map> stats = adStats.getStats(); - Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); + Map> stats = adStats.getStats(); + Set> nodeStats = new HashSet<>(adStats.getNodeStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getNodeStats returns incorrect stat", (stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat)) @@ -172,10 +180,10 @@ public void testGetNodeStats() { @Test public void testGetClusterStats() { - Map> stats = adStats.getStats(); - Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); + Map> stats = adStats.getStats(); + Set> clusterStats = new HashSet<>(adStats.getClusterStats().values()); - for (ADStat stat : stats.values()) { + for (TimeSeriesStat stat : stats.values()) { assertTrue( "getClusterStats returns incorrect stat", (stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat)) diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java index 333d50ffe..3490e0318 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/CounterSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; public class CounterSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java index cfdf71188..409437490 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/IndexSupplierTests.java @@ -16,8 +16,9 @@ import org.junit.Before; import org.junit.Test; -import org.opensearch.ad.util.IndexUtils; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier; +import org.opensearch.timeseries.util.IndexUtils; public class IndexSupplierTests extends OpenSearchTestCase { private IndexUtils indexUtils; diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java index c0173593c..36e29a335 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java +++ 
b/src/test/java/org/opensearch/ad/stats/suppliers/ModelsOnNodeSupplierTests.java @@ -13,8 +13,8 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE; -import static org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE; +import static org.opensearch.timeseries.stats.suppliers.ModelsOnNodeSupplier.MODEL_STATE_STAT_KEYS; import java.time.Clock; import java.util.ArrayList; @@ -30,16 +30,19 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.ml.EntityModel; +import org.opensearch.ad.ml.ADModelManager; +import org.opensearch.ad.ml.ADModelState; import org.opensearch.ad.ml.HybridThresholdingModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.caching.EntityCache; +import org.opensearch.timeseries.caching.HCCacheProvider; +import org.opensearch.timeseries.ml.ModelManager; +import org.opensearch.timeseries.ml.createFromValueOnlySamples; +import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeCountSupplier; +import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeSupplier; import test.org.opensearch.ad.util.MLUtil; import test.org.opensearch.ad.util.RandomModelStateConfig; @@ -49,15 +52,15 @@ public class ModelsOnNodeSupplierTests extends OpenSearchTestCase { private RandomCutForest rcf; private HybridThresholdingModel thresholdingModel; - private List> expectedResults; + private List> expectedResults; private Clock clock; - private List> entityModelsInformation; + private List> entityModelsInformation; @Mock - private ModelManager modelManager; + private ADModelManager modelManager; @Mock - private CacheProvider cacheProvider; + private HCCacheProvider cacheProvider; @Before public void setup() { @@ -70,17 +73,26 @@ public void setup() { expectedResults = new ArrayList<>( Arrays .asList( - new ModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), - new ModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), - new ModelState<>(thresholdingModel, "thr-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f) + new ADModelState<>(rcf, "rcf-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ADModelState<>(thresholdingModel, "thr-model-1", "detector-1", ModelManager.ModelType.RCF.getName(), clock, 0f), + new ADModelState<>(rcf, "rcf-model-2", "detector-2", ModelManager.ModelType.THRESHOLD.getName(), clock, 0f), + new ADModelState<>( + thresholdingModel, + "thr-model-2", + "detector-2", + ModelManager.ModelType.THRESHOLD.getName(), + clock, + 0f + ) ) ); when(modelManager.getAllModels()).thenReturn(expectedResults); - ModelState entityModel1 = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); - ModelState entityModel2 = 
MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState entityModel1 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState entityModel2 = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); entityModelsInformation = new ArrayList<>(Arrays.asList(entityModel1, entityModel2)); EntityCache cache = mock(EntityCache.class); @@ -90,11 +102,11 @@ public void setup() { @Test public void testGet() { - Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); + Settings settings = Settings.builder().put(AD_MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build(); ClusterService clusterService = mock(ClusterService.class); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AD_MAX_MODEL_SIZE_PER_NODE))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); diff --git a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java index 1cf1c9306..821871984 100644 --- a/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java +++ b/src/test/java/org/opensearch/ad/stats/suppliers/SettableSupplierTests.java @@ -13,6 +13,7 @@ import org.junit.Test; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.stats.suppliers.SettableSupplier; public class SettableSupplierTests extends OpenSearchTestCase { @Test diff --git a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java index ba9698d6a..ea311d97d 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskCacheManagerTests.java @@ -33,9 +33,7 @@ import org.junit.After; import org.junit.Before; -import org.opensearch.ad.MemoryTracker; import org.opensearch.ad.model.ADTask; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; @@ -46,11 +44,13 @@ import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.DuplicateTaskException; import org.opensearch.timeseries.common.exception.LimitExceededException; +import org.opensearch.timeseries.model.TaskState; +import org.opensearch.timeseries.task.RealtimeTaskCache; import com.google.common.collect.ImmutableList; public class ADTaskCacheManagerTests extends OpenSearchTestCase { - private MemoryTracker memoryTracker; + private ADMemoryTracker memoryTracker; private ADTaskCacheManager adTaskCacheManager; private ClusterService clusterService; private Settings settings; @@ -77,7 +77,7 @@ public void setUp() throws Exception { ) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - memoryTracker = mock(MemoryTracker.class); + memoryTracker = mock(ADMemoryTracker.class); adTaskCacheManager = new ADTaskCacheManager(settings, clusterService, memoryTracker); } @@ -113,12 +113,12 @@ public void testPutDuplicateTask() throws IOException { ADTask adTask2 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, adTask1.getExecutionEndTime(), adTask1.getStoppedBy(), adTask1.getId(), 
adTask1.getDetector(), - ADTaskType.HISTORICAL_SINGLE_ENTITY + ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR ); DuplicateTaskException e2 = expectThrows(DuplicateTaskException.class, () -> adTaskCacheManager.add(adTask2)); assertEquals(DETECTOR_IS_RUNNING, e2.getMessage()); @@ -137,7 +137,7 @@ public void testPutMultipleEntityTasks() throws IOException { ADTask adTask1 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detector.getId(), @@ -147,7 +147,7 @@ public void testPutMultipleEntityTasks() throws IOException { ADTask adTask2 = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detector.getId(), @@ -310,7 +310,7 @@ public void testPushBackEntity() throws IOException { public void testRealtimeTaskCache() { String detectorId1 = randomAlphaOfLength(10); - String newState = ADTaskState.INIT.name(); + String newState = TaskState.INIT.name(); Float newInitProgress = 0.0f; String newError = randomAlphaOfLength(5); assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); @@ -328,7 +328,7 @@ public void testRealtimeTaskCache() { adTaskCacheManager.updateRealtimeTaskCache(detectorId2, newState, newInitProgress, newError); assertEquals(2, adTaskCacheManager.getDetectorIdsInRealtimeTaskCache().length); - newState = ADTaskState.RUNNING.name(); + newState = TaskState.RUNNING.name(); newInitProgress = 1.0f; newError = "test error"; assertTrue(adTaskCacheManager.isRealtimeTaskChangeNeeded(detectorId1, newState, newInitProgress, newError)); @@ -349,12 +349,12 @@ public void testUpdateRealtimeTaskCache() { String detectorId = randomAlphaOfLength(5); adTaskCacheManager.initRealtimeTaskCache(detectorId, 60_000); adTaskCacheManager.updateRealtimeTaskCache(detectorId, null, null, null); - ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); assertNull(realtimeTaskCache.getState()); assertNull(realtimeTaskCache.getError()); assertNull(realtimeTaskCache.getInitProgress()); - String state = ADTaskState.RUNNING.name(); + String state = TaskState.RUNNING.name(); Float initProgress = 0.1f; String error = randomAlphaOfLength(5); adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); @@ -363,7 +363,7 @@ public void testUpdateRealtimeTaskCache() { assertEquals(error, realtimeTaskCache.getError()); assertEquals(initProgress, realtimeTaskCache.getInitProgress()); - state = ADTaskState.STOPPED.name(); + state = TaskState.STOPPED.name(); adTaskCacheManager.updateRealtimeTaskCache(detectorId, state, initProgress, error); realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); assertNull(realtimeTaskCache); @@ -379,10 +379,10 @@ public void testGetAndDecreaseEntityTaskLanes() throws IOException { public void testDeletedTask() { String taskId = randomAlphaOfLength(10); - adTaskCacheManager.addDeletedDetectorTask(taskId); - assertTrue(adTaskCacheManager.hasDeletedDetectorTask()); - assertEquals(taskId, adTaskCacheManager.pollDeletedDetectorTask()); - assertFalse(adTaskCacheManager.hasDeletedDetectorTask()); + adTaskCacheManager.addDeletedTask(taskId); + assertTrue(adTaskCacheManager.hasDeletedTask()); + assertEquals(taskId, adTaskCacheManager.pollDeletedTask()); + assertFalse(adTaskCacheManager.hasDeletedTask()); } public void testAcquireTaskUpdatingSemaphore() 
throws IOException, InterruptedException { @@ -434,7 +434,7 @@ private List addHCDetectorCache() throws IOException { ADTask adDetectorTask = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detectorId, @@ -444,7 +444,7 @@ private List addHCDetectorCache() throws IOException { ADTask adEntityTask = TestHelpers .randomAdTask( randomAlphaOfLength(5), - ADTaskState.CREATED, + TaskState.CREATED, Instant.now(), null, detectorId, @@ -527,7 +527,7 @@ public void testTaskLanes() throws IOException { public void testRefreshRealtimeJobRunTime() throws InterruptedException { String detectorId = randomAlphaOfLength(5); adTaskCacheManager.initRealtimeTaskCache(detectorId, 1_000); - ADRealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); + RealtimeTaskCache realtimeTaskCache = adTaskCacheManager.getRealtimeTaskCache(detectorId); assertFalse(realtimeTaskCache.expired()); Thread.sleep(3_000); assertTrue(realtimeTaskCache.expired()); @@ -537,7 +537,7 @@ public void testRefreshRealtimeJobRunTime() throws InterruptedException { public void testAddDeletedDetector() { String detectorId = randomAlphaOfLength(5); - adTaskCacheManager.addDeletedDetector(detectorId); + adTaskCacheManager.addDeletedConfig(detectorId); String polledDetectorId = adTaskCacheManager.pollDeletedDetector(); assertEquals(detectorId, polledDetectorId); assertNull(adTaskCacheManager.pollDeletedDetector()); @@ -621,11 +621,11 @@ public void testADHCBatchTaskRunStateCacheWithCancel() { ADHCBatchTaskRunState state = adTaskCacheManager.getOrCreateHCDetectorTaskStateCache(detectorId, detectorTaskId); assertTrue(adTaskCacheManager.detectorTaskStateExists(detectorId, detectorTaskId)); - assertEquals(ADTaskState.INIT.name(), state.getDetectorTaskState()); + assertEquals(TaskState.INIT.name(), state.getDetectorTaskState()); assertFalse(state.expired()); - state.setDetectorTaskState(ADTaskState.RUNNING.name()); - assertEquals(ADTaskState.RUNNING.name(), adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); + state.setDetectorTaskState(TaskState.RUNNING.name()); + assertEquals(TaskState.RUNNING.name(), adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); String cancelReason = randomAlphaOfLength(5); String cancelledBy = randomAlphaOfLength(5); @@ -647,7 +647,7 @@ public void testADHCBatchTaskRunStateCacheWithCancel() { public void testUpdateDetectorTaskState() { String detectorId = randomAlphaOfLength(5); String detectorTaskId = randomAlphaOfLength(5); - String newState = ADTaskState.RUNNING.name(); + String newState = TaskState.RUNNING.name(); adTaskCacheManager.updateDetectorTaskState(detectorId, detectorTaskId, newState); assertEquals(newState, adTaskCacheManager.getDetectorTaskState(detectorId, detectorTaskId)); diff --git a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java index f1b67e71e..621d3015c 100644 --- a/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java +++ b/src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java @@ -30,12 +30,12 @@ import static org.mockito.Mockito.when; import static org.opensearch.ad.constant.ADCommonName.ANOMALY_RESULT_INDEX_ALIAS; import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_REQUEST_TIMEOUT; import static org.opensearch.ad.settings.AnomalyDetectorSettings.BATCH_TASK_PIECE_INTERVAL_SECONDS; 
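(Aside, not part of the patch.) The `REQUEST_TIMEOUT` to `AD_REQUEST_TIMEOUT` rename below has to land in three places at once: this static import, the `Settings.builder().put(...)` value, and the `clusterSetting(...)` registration, because update consumers can only be attached to settings that were registered with `ClusterSettings`. A rough plain-Mockito equivalent of that setup, assuming the `clusterSetting` helper in `ADUnitTestCase` simply wraps the `ClusterSettings` constructor:

    // Assumed equivalent of the clusterSetting(...) helper used below;
    // the setting constants are the ones from the surrounding diff.
    ClusterSettings clusterSettings = new ClusterSettings(
        settings,
        new HashSet<>(Arrays.asList(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, BATCH_TASK_PIECE_INTERVAL_SECONDS, AD_REQUEST_TIMEOUT))
    );
    ClusterService clusterService = mock(ClusterService.class);
    when(clusterService.getClusterSettings()).thenReturn(clusterSettings);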
import static org.opensearch.ad.settings.AnomalyDetectorSettings.DELETE_AD_RESULT_WHEN_DELETE_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_BATCH_TASK_PER_NODE; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_OLD_AD_TASK_DOCS_PER_DETECTOR; import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.REQUEST_TIMEOUT; import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; import static org.opensearch.timeseries.TestHelpers.randomAdTask; import static org.opensearch.timeseries.TestHelpers.randomAnomalyDetector; @@ -45,7 +45,6 @@ import static org.opensearch.timeseries.TestHelpers.randomIntervalSchedule; import static org.opensearch.timeseries.TestHelpers.randomIntervalTimeConfiguration; import static org.opensearch.timeseries.TestHelpers.randomUser; -import static org.opensearch.timeseries.constant.CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED; import static org.opensearch.timeseries.model.Entity.createSingleAttributeEntity; import java.io.IOException; @@ -81,23 +80,16 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.update.UpdateResponse; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.mock.model.MockSimpleLog; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskAction; import org.opensearch.ad.model.ADTaskProfile; -import org.opensearch.ad.model.ADTaskState; import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; -import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler; -import org.opensearch.ad.stats.InternalStatNames; import org.opensearch.ad.transport.ADStatsNodeResponse; import org.opensearch.ad.transport.ADStatsNodesResponse; import org.opensearch.ad.transport.ADTaskProfileNodeResponse; import org.opensearch.ad.transport.ADTaskProfileResponse; -import org.opensearch.ad.transport.AnomalyDetectorJobResponse; import org.opensearch.ad.transport.ForwardADTaskRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; @@ -129,6 +121,9 @@ import org.opensearch.timeseries.function.ExecutorFunction; import org.opensearch.timeseries.model.DateRange; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.rest.handler.IndexJobActionHandler; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -153,10 +148,10 @@ public class ADTaskManagerTests extends ADUnitTestCase { private TransportService transportService; private ADTaskManager adTaskManager; private ThreadPool threadPool; - private IndexAnomalyDetectorJobActionHandler indexAnomalyDetectorJobActionHandler; + private IndexJobActionHandler indexAnomalyDetectorJobActionHandler; private DateRange detectionDateRange; - private ActionListener listener; + private ActionListener listener; private DiscoveryNode node1; private DiscoveryNode node2; @@ -199,7 +194,7 @@ public class ADTaskManagerTests extends ADUnitTestCase { + 
",\"parent_task_id\":\"a1civ3sBwF58XZxvKrko\",\"worker_node\":\"DL5uOJV3TjOOAyh5hJXrCA\",\"current_piece\"" + ":1630999260000,\"execution_end_time\":1630999442814}}"; @Captor - ArgumentCaptor> remoteResponseHandler; + ArgumentCaptor> remoteResponseHandler; @Override public void setUp() throws Exception { @@ -213,14 +208,14 @@ public void setUp() throws Exception { .builder() .put(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.getKey(), 2) .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) - .put(REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) + .put(AD_REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) .build(); clusterSettings = clusterSetting( settings, MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, BATCH_TASK_PIECE_INTERVAL_SECONDS, - REQUEST_TIMEOUT, + AD_REQUEST_TIMEOUT, DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, MAX_BATCH_TASK_PER_NODE, MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS @@ -239,7 +234,7 @@ public void setUp() throws Exception { threadContext = new ThreadContext(settings); when(threadPool.getThreadContext()).thenReturn(threadContext); when(client.threadPool()).thenReturn(threadPool); - indexAnomalyDetectorJobActionHandler = mock(IndexAnomalyDetectorJobActionHandler.class); + indexAnomalyDetectorJobActionHandler = mock(IndexJobActionHandler.class); adTaskManager = spy( new ADTaskManager( settings, @@ -254,9 +249,9 @@ public void setUp() throws Exception { ) ); - listener = spy(new ActionListener() { + listener = spy(new ActionListener() { @Override - public void onResponse(AnomalyDetectorJobResponse bulkItemResponses) {} + public void onResponse(JobResponse bulkItemResponses) {} @Override public void onFailure(Exception e) {} @@ -312,7 +307,7 @@ private void setupHashRingWithSameLocalADVersionNodes() { Consumer function = invocation.getArgument(0); function.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getNodesWithSameLocalAdVersion(any(), any()); + }).when(hashRing).getNodesWithSameLocalVersion(any(), any()); } private void setupHashRingWithOwningNode() { @@ -320,7 +315,7 @@ private void setupHashRingWithOwningNode() { Consumer> function = invocation.getArgument(1); function.accept(Optional.of(node1)); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); } public void testCreateTaskIndexNotAcknowledged() throws IOException { @@ -335,7 +330,7 @@ public void testCreateTaskIndexNotAcknowledged() throws IOException { adTaskManager.startDetector(detector, detectionDateRange, randomUser(), transportService, listener); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - String error = String.format(Locale.ROOT, CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); + String error = String.format(Locale.ROOT, CommonMessages.CREATE_INDEX_NOT_ACKNOWLEDGED, DETECTION_STATE_INDEX); assertEquals(error, exceptionCaptor.getValue().getMessage()); } @@ -450,7 +445,7 @@ private void setupTaskSlots(int node1UsedTaskSlots, int node1AssignedTaskSLots, public void testCheckTaskSlotsWithNoAvailableTaskSlots() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -475,7 +470,7 @@ private void setupSearchTopEntities(int entitySize) { public void testCheckTaskSlotsWithAvailableTaskSlotsForHC() throws 
IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -494,7 +489,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForHC() throws IOException { public void testCheckTaskSlotsWithAvailableTaskSlotsForSingleEntityDetector() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of()) @@ -512,7 +507,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForSingleEntityDetector() th public void testCheckTaskSlotsWithAvailableTaskSlotsAndNoEntity() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -530,7 +525,7 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsAndNoEntity() throws IOExcep public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -555,14 +550,14 @@ public void testCheckTaskSlotsWithAvailableTaskSlotsForScale() throws IOExceptio public void testDeleteDuplicateTasks() throws IOException { ADTask adTask = randomAdTask(); - adTaskManager.handleADTaskException(adTask, new DuplicateTaskException("test")); + adTaskManager.handleTaskException(adTask, new DuplicateTaskException("test")); verify(client, times(1)).delete(any(), any()); } public void testParseEntityForSingleCategoryHC() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers.randomAnomalyDetectorUsingCategoryFields(randomAlphaOfLength(5), ImmutableList.of(randomAlphaOfLength(5))) @@ -575,7 +570,7 @@ public void testParseEntityForSingleCategoryHC() throws IOException { public void testParseEntityForMultiCategoryHC() throws IOException { ADTask adTask = randomAdTask( randomAlphaOfLength(5), - ADTaskState.INIT, + TaskState.INIT, Instant.now(), randomAlphaOfLength(5), TestHelpers @@ -594,7 +589,7 @@ public void testDetectorTaskSlotScaleUpDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int taskSlots = maxRunningEntities - 1; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -607,7 +602,7 @@ public void testDetectorTaskSlotScaleDownDelta() { DiscoveryNode[] eligibleDataNodes = new DiscoveryNode[] { node1, node2 }; // Scale down - when(hashRing.getNodesWithSameLocalAdVersion()).thenReturn(eligibleDataNodes); + when(hashRing.getNodesWithSameLocalVersion()).thenReturn(eligibleDataNodes); when(adTaskCacheManager.getUnfinishedEntityCount(detectorId)).thenReturn(maxRunningEntities * 10); int 
taskSlots = maxRunningEntities * 5; when(adTaskCacheManager.getDetectorTaskSlots(detectorId)).thenReturn(taskSlots); @@ -715,7 +710,7 @@ public void testGetADTaskWithExistingTask() { @SuppressWarnings("unchecked") public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { String detectorId = randomAlphaOfLength(5); - String state = ADTaskState.RUNNING.name(); + String state = TaskState.RUNNING.name(); Long rcfTotalUpdates = randomLongBetween(200, 1000); Long detectorIntervalInMinutes = 1L; String error = randomAlphaOfLength(5); @@ -726,7 +721,7 @@ public void testUpdateLatestRealtimeTaskOnCoordinatingNode() { ActionListener listener = invocation.getArgument(3); listener.onResponse(new UpdateResponse(ShardId.fromString("[test][1]"), "1", 0L, 1L, 1L, DocWriteResponse.Result.UPDATED)); return null; - }).when(adTaskManager).updateLatestADTask(anyString(), any(), anyMap(), any()); + }).when(adTaskManager).updateLatestTask(anyString(), any(), anyMap(), any()); adTaskManager .updateLatestRealtimeTaskOnCoordinatingNode( detectorId, @@ -774,7 +769,7 @@ public void testGetLocalADTaskProfilesByDetectorId() { @SuppressWarnings("unchecked") public void testRemoveStaleRunningEntity() throws IOException { - ActionListener actionListener = mock(ActionListener.class); + ActionListener actionListener = mock(ActionListener.class); ADTask adTask = randomAdTask(); String entity = randomAlphaOfLength(5); ExecutorService executeService = mock(ExecutorService.class); @@ -847,7 +842,7 @@ public void testCleanADResultOfDeletedDetectorWithException() { .builder() .put(MAX_OLD_AD_TASK_DOCS_PER_DETECTOR.getKey(), 2) .put(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 1) - .put(REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) + .put(AD_REQUEST_TIMEOUT.getKey(), TimeValue.timeValueSeconds(10)) .put(DELETE_AD_RESULT_WHEN_DELETE_DETECTOR.getKey(), true) .build(); @@ -855,7 +850,7 @@ public void testCleanADResultOfDeletedDetectorWithException() { settings, MAX_OLD_AD_TASK_DOCS_PER_DETECTOR, BATCH_TASK_PIECE_INTERVAL_SECONDS, - REQUEST_TIMEOUT, + AD_REQUEST_TIMEOUT, DELETE_AD_RESULT_WHEN_DELETE_DETECTOR, MAX_BATCH_TASK_PER_NODE, MAX_RUNNING_ENTITIES_PER_DETECTOR_FOR_HISTORICAL_ANALYSIS @@ -878,28 +873,28 @@ public void testCleanADResultOfDeletedDetectorWithException() { ); adTaskManager.cleanADResultOfDeletedDetector(); verify(client, times(1)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); - verify(adTaskCacheManager, times(1)).addDeletedDetector(eq(detectorId)); + verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); adTaskManager.cleanADResultOfDeletedDetector(); verify(client, times(2)).execute(eq(DeleteByQueryAction.INSTANCE), any(), any()); - verify(adTaskCacheManager, times(1)).addDeletedDetector(eq(detectorId)); + verify(adTaskCacheManager, times(1)).addDeletedConfig(eq(detectorId)); } public void testMaintainRunningHistoricalTasksWithOwningNodeIsNotLocalNode() { // Test no owning node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.empty()); adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); // Test owning node is not local node - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node2)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node2)); doReturn(node1).when(clusterService).localNode(); 
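(Aside, not part of the patch.) The `getNodesWithSameLocalAdVersion` to `getNodesWithSameLocalVersion` rename ripples through stubs like the ones above. The recurring idiom, sketched here with signatures inferred from the surrounding stubs: the hash ring delivers nodes through a callback rather than a return value, so the Mockito stub pulls the `Consumer` out of the invocation and runs it synchronously.

    // Sketch of the callback-stubbing idiom used throughout this test class;
    // node1, node2, and hashRing are fields of the surrounding test.
    doAnswer(invocation -> {
        Consumer<DiscoveryNode[]> callback = invocation.getArgument(0);
        callback.accept(new DiscoveryNode[] { node1, node2 }); // fake eligible nodes
        return null;
    }).when(hashRing).getNodesWithSameLocalVersion(any(), any());

Because `.when(hashRing).getNodesWithSameLocalVersion(...)` invokes the method directly, the rename is caught at compile time at every stub site, which is why so many hunks in this file change mechanically.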
adTaskManager.maintainRunningHistoricalTasks(transportService, 10); verify(client, never()).search(any(), any()); } public void testMaintainRunningHistoricalTasksWithNoRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { @@ -932,7 +927,7 @@ public void testMaintainRunningHistoricalTasksWithNoRunningTask() { } public void testMaintainRunningHistoricalTasksWithRunningTask() { - when(hashRing.getOwningNodeWithHighestAdVersion(anyString())).thenReturn(Optional.of(node1)); + when(hashRing.getOwningNodeWithHighestVersion(anyString())).thenReturn(Optional.of(node1)); doReturn(node1).when(clusterService).localNode(); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); @@ -987,11 +982,11 @@ public void testMaintainRunningRealtimeTasks() { when(adTaskCacheManager.getDetectorIdsInRealtimeTaskCache()).thenReturn(new String[] { detectorId1, detectorId2, detectorId3 }); when(adTaskCacheManager.getRealtimeTaskCache(detectorId1)).thenReturn(null); - ADRealtimeTaskCache cacheOfDetector2 = mock(ADRealtimeTaskCache.class); + RealtimeTaskCache cacheOfDetector2 = mock(RealtimeTaskCache.class); when(cacheOfDetector2.expired()).thenReturn(false); when(adTaskCacheManager.getRealtimeTaskCache(detectorId2)).thenReturn(cacheOfDetector2); - ADRealtimeTaskCache cacheOfDetector3 = mock(ADRealtimeTaskCache.class); + RealtimeTaskCache cacheOfDetector3 = mock(RealtimeTaskCache.class); when(cacheOfDetector3.expired()).thenReturn(true); when(adTaskCacheManager.getRealtimeTaskCache(detectorId3)).thenReturn(cacheOfDetector3); @@ -1005,12 +1000,12 @@ public void testStartHistoricalAnalysisWithNoOwningNode() throws IOException { DateRange detectionDateRange = TestHelpers.randomDetectionDateRange(); User user = null; int availableTaskSlots = randomIntBetween(1, 10); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(anyString(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(anyString(), any(), any()); adTaskManager.startHistoricalAnalysis(detector, detectionDateRange, user, availableTaskSlots, transportService, listener); verify(listener, times(1)).onFailure(any()); } @@ -1034,7 +1029,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopp .detectorId(randomAlphaOfLength(5)) .detector(detector) .entity(null) - .state(ADTaskState.RUNNING.name()) + .state(TaskState.RUNNING.name()) .taskProgress(0.5f) .initProgress(1.0f) .currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) @@ -1067,7 +1062,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningRealtimeTaskWithTaskStopp ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1100,7 +1095,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws I .detectorId(randomAlphaOfLength(5)) .detector(detector) .entity(null) - .state(ADTaskState.RUNNING.name()) + .state(TaskState.RUNNING.name()) .taskProgress(0.5f) .initProgress(1.0f) 
.currentPiece(Instant.now().truncatedTo(ChronoUnit.SECONDS).minus(randomIntBetween(1, 100), ChronoUnit.MINUTES)) @@ -1133,7 +1128,7 @@ public void testGetAndExecuteOnLatestADTasksWithRunningHistoricalTask() throws I ); setupGetAndExecuteOnLatestADTasks(profile); adTaskManager - .getAndExecuteOnLatestADTasks( + .getAndExecuteOnLatestTasks( detectorId, null, null, @@ -1186,13 +1181,13 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { }).when(client).search(any(), any()); String detectorId = randomAlphaOfLength(5); Consumer> function = mock(Consumer.class); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer getNodeFunction = invocation.getArgument(0); getNodeFunction.accept(new DiscoveryNode[] { node1, node2 }); return null; - }).when(hashRing).getAllEligibleDataNodesWithKnownAdVersion(any(), any()); + }).when(hashRing).getAllEligibleDataNodesWithKnownVersion(any(), any()); doAnswer(invocation -> { ActionListener taskProfileResponseListener = invocation.getArgument(2); @@ -1237,7 +1232,7 @@ private void setupGetAndExecuteOnLatestADTasks(ADTaskProfile adTaskProfile) { true, BytesReference .bytes( - new AnomalyDetectorJob( + new Job( detectorId, randomIntervalSchedule(), randomIntervalTimeConfiguration(), @@ -1266,7 +1261,7 @@ public void testCreateADTaskDirectlyWithException() throws IOException { ActionListener listener = mock(ActionListener.class); doThrow(new RuntimeException("test")).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(1)).onFailure(any()); doAnswer(invocation -> { @@ -1274,19 +1269,19 @@ public void testCreateADTaskDirectlyWithException() throws IOException { actionListener.onFailure(new RuntimeException("test")); return null; }).when(client).index(any(), any()); - adTaskManager.createADTaskDirectly(adTask, function, listener); + adTaskManager.createTaskDirectly(adTask, function, listener); verify(listener, times(2)).onFailure(any()); } public void testCleanChildTasksAndADResultsOfDeletedTaskWithNoDeletedDetectorTask() { - when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(false); + when(adTaskCacheManager.hasDeletedTask()).thenReturn(false); adTaskManager.cleanChildTasksAndADResultsOfDeletedTask(); verify(client, never()).execute(any(), any(), any()); } public void testCleanChildTasksAndADResultsOfDeletedTaskWithNullTask() { - when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); - when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(null); + when(adTaskCacheManager.hasDeletedTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedTask()).thenReturn(null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); actionListener.onFailure(new RuntimeException("test")); @@ -1304,8 +1299,8 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithNullTask() { } public void testCleanChildTasksAndADResultsOfDeletedTaskWithFailToDeleteADResult() { - when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); - when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(randomAlphaOfLength(5)); + when(adTaskCacheManager.hasDeletedTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedTask()).thenReturn(randomAlphaOfLength(5)); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); 
actionListener.onFailure(new RuntimeException("test")); @@ -1323,8 +1318,8 @@ public void testCleanChildTasksAndADResultsOfDeletedTaskWithFailToDeleteADResult } public void testCleanChildTasksAndADResultsOfDeletedTask() { - when(adTaskCacheManager.hasDeletedDetectorTask()).thenReturn(true); - when(adTaskCacheManager.pollDeletedDetectorTask()).thenReturn(randomAlphaOfLength(5)).thenReturn(null); + when(adTaskCacheManager.hasDeletedTask()).thenReturn(true); + when(adTaskCacheManager.pollDeletedTask()).thenReturn(randomAlphaOfLength(5)).thenReturn(null); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); BulkByScrollResponse response = mock(BulkByScrollResponse.class); @@ -1412,7 +1407,7 @@ public void testDeleteADTasksWithException() { @SuppressWarnings("unchecked") public void testScaleUpTaskSlots() throws IOException { ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); when(adTaskCacheManager.getAvailableNewEntityTaskLanes(anyString())).thenReturn(0); doReturn(2).when(adTaskManager).detectorTaskSlotScaleDelta(anyString()); when(adTaskCacheManager.getLastScaleEntityTaskLaneTime(anyString())).thenReturn(null); @@ -1432,12 +1427,12 @@ public void testScaleUpTaskSlots() throws IOException { public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException { ADTask adTask = randomAdTask(ADTaskType.HISTORICAL_HC_ENTITY); ForwardADTaskRequest forwardADTaskRequest = new ForwardADTaskRequest(adTask, ADTaskAction.APPLY_FOR_TASK_SLOTS); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); return null; - }).when(hashRing).buildAndGetOwningNodeWithSameLocalAdVersion(any(), any(), any()); + }).when(hashRing).buildAndGetOwningNodeWithSameLocalVersion(any(), any(), any()); adTaskManager.forwardRequestToLeadNode(forwardADTaskRequest, transportService, listener); verify(listener, times(1)).onFailure(any()); @@ -1448,7 +1443,7 @@ public void testScaleTaskLaneOnCoordinatingNode() { ADTask adTask = mock(ADTask.class); when(adTask.getCoordinatingNode()).thenReturn(node1.getId()); when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 }); - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, transportService, listener); } @@ -1457,7 +1452,7 @@ public void testStartDetectorWithException() throws IOException { AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); DateRange detectionDateRange = randomDetectionDateRange(); User user = null; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); when(detectionIndices.doesStateIndexExist()).thenReturn(false); doThrow(new RuntimeException("test")).when(detectionIndices).initStateIndex(any()); adTaskManager.startDetector(detector, detectionDateRange, user, transportService, listener); @@ -1468,7 +1463,7 @@ public void testStartDetectorWithException() throws IOException { public void testStopDetectorWithNonExistingDetector() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = 
mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); function.accept(Optional.empty()); @@ -1482,7 +1477,7 @@ public void testStopDetectorWithNonExistingDetector() { public void testStopDetectorWithNonExistingTask() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); @@ -1504,7 +1499,7 @@ public void testStopDetectorWithNonExistingTask() { public void testStopDetectorWithTaskDone() { String detectorId = randomAlphaOfLength(5); boolean historical = true; - ActionListener listener = mock(ActionListener.class); + ActionListener listener = mock(ActionListener.class); doAnswer(invocation -> { Consumer> function = invocation.getArgument(1); AnomalyDetector detector = randomAnomalyDetector(ImmutableList.of(randomFeature(true))); @@ -1621,7 +1616,7 @@ public void testDeleteTaskDocs() { ExecutorFunction function = mock(ExecutorFunction.class); ActionListener listener = mock(ActionListener.class); adTaskManager.deleteTaskDocs(detectorId, searchRequest, function, listener); - verify(adTaskCacheManager, times(1)).addDeletedDetectorTask(anyString()); + verify(adTaskCacheManager, times(1)).addDeletedTask(anyString()); verify(function, times(1)).execute(); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java index 1e3a3506e..60f68131e 100644 --- a/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADBatchAnomalyResultTransportActionTests.java @@ -28,13 +28,14 @@ import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.util.ExceptionUtil; import org.opensearch.common.settings.Settings; import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.common.exception.EndRunException; import org.opensearch.timeseries.model.DateRange; +import org.opensearch.timeseries.model.TimeSeriesTask; +import org.opensearch.timeseries.util.ExceptionUtil; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -102,7 +103,7 @@ public void testHistoricalAnalysisWithValidDateRange() throws IOException, Inter client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(20000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); } public void testHistoricalAnalysisWithNonExistingIndex() throws IOException { @@ -140,7 +141,7 @@ public void testDisableADPlugin() throws IOException { ImmutableList.of(NotSerializableExceptionWrapper.class, EndRunException.class), () -> client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(10000) ); - 
assertTrue(exception.getMessage(), exception.getMessage().contains("AD functionality is disabled")); + assertTrue(exception.getMessage(), exception.getMessage().contains("AD plugin is disabled")); updateTransientSettings(ImmutableMap.of(AD_ENABLED, false)); } finally { // guarantee reset back to default @@ -162,7 +163,7 @@ public void testMultipleTasks() throws IOException, InterruptedException { client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(25000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(ADTask.STATE_FIELD))); + assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(doc.getSourceAsMap().get(TimeSeriesTask.STATE_FIELD))); updateTransientSettings(ImmutableMap.of(MAX_BATCH_TASK_PER_NODE.getKey(), 1)); } @@ -187,6 +188,6 @@ private void testInvalidDetectionDateRange(DateRange dateRange, String error) th client().execute(ADBatchAnomalyResultAction.INSTANCE, request).actionGet(5000); Thread.sleep(5000); GetResponse doc = getDoc(ADCommonName.DETECTION_STATE_INDEX, request.getAdTask().getTaskId()); - assertEquals(error, doc.getSourceAsMap().get(ADTask.ERROR_FIELD)); + assertEquals(error, doc.getSourceAsMap().get(TimeSeriesTask.ERROR_FIELD)); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java index 6946953fc..076c8763c 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkResponseTests.java @@ -20,16 +20,17 @@ import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.timeseries.transport.ResultBulkResponse; public class ADResultBulkResponseTests extends OpenSearchTestCase { public void testSerialization() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); List retryRequests = new ArrayList<>(); retryRequests.add(new IndexRequest("index").id("blah").source(Collections.singletonMap("foo", "bar"))); - ADResultBulkResponse response = new ADResultBulkResponse(retryRequests); + ResultBulkResponse response = new ResultBulkResponse(retryRequests); response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - ADResultBulkResponse readResponse = new ADResultBulkResponse(streamInput); + ResultBulkResponse readResponse = new ResultBulkResponse(streamInput); assertTrue(readResponse.hasFailures()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java index 432849b82..a5d855f19 100644 --- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java @@ -41,6 +41,8 @@ import org.opensearch.index.IndexingPressure; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.constant.CommonMessages; +import org.opensearch.timeseries.transport.ResultBulkResponse; import org.opensearch.transport.TransportService; public class ADResultBulkTransportActionTests extends AbstractTimeSeriesTest { @@ -118,7 +120,7 @@ public void testSendAll() { return null; 
diff --git a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java
index 432849b82..a5d855f19 100644
--- a/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ADResultBulkTransportActionTests.java
@@ -41,6 +41,8 @@
 import org.opensearch.index.IndexingPressure;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.transport.ResultBulkResponse;
 import org.opensearch.transport.TransportService;
 
 public class ADResultBulkTransportActionTests extends AbstractTimeSeriesTest {
@@ -118,7 +120,7 @@ public void testSendAll() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
@@ -151,7 +153,7 @@ public void testSendPartial() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
@@ -190,7 +192,7 @@ public void testSendRandomPartial() {
             return null;
         }).when(client).execute(any(), any(), any());
 
-        PlainActionFuture<ADResultBulkResponse> future = PlainActionFuture.newFuture();
+        PlainActionFuture<ResultBulkResponse> future = PlainActionFuture.newFuture();
         resultBulk.doExecute(null, originalRequest, future);
 
         future.actionGet();
@@ -210,6 +212,6 @@ public void testSerialzationRequest() throws IOException {
 
     public void testValidateRequest() {
         ActionRequestValidationException e = new ADResultBulkRequest().validate();
-        assertThat(e.validationErrors(), hasItem(ADResultBulkRequest.NO_REQUESTS_ADDED_ERR));
+        assertThat(e.validationErrors(), hasItem(CommonMessages.NO_REQUESTS_ADDED_ERR));
     }
 }
diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java
index 95799f911..23576d1f6 100644
--- a/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ADStatsNodesTransportActionTests.java
@@ -13,7 +13,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE;
 
 import java.time.Clock;
 import java.util.Arrays;
@@ -26,19 +26,9 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.stats.ADStat;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.InternalStatNames;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
-import org.opensearch.ad.stats.suppliers.IndexStatusSupplier;
-import org.opensearch.ad.stats.suppliers.ModelsOnNodeSupplier;
-import org.opensearch.ad.stats.suppliers.SettableSupplier;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.util.ClientUtil;
-import org.opensearch.ad.util.IndexUtils;
 import org.opensearch.ad.util.Throttler;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
@@ -49,13 +39,23 @@
 import org.opensearch.monitor.jvm.JvmStats;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.caching.EntityCache;
+import org.opensearch.timeseries.caching.HCCacheProvider;
+import org.opensearch.timeseries.stats.InternalStatNames;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.stats.suppliers.IndexStatusSupplier;
+import org.opensearch.timeseries.stats.suppliers.ModelsOnNodeSupplier;
+import org.opensearch.timeseries.stats.suppliers.SettableSupplier;
+import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 import org.opensearch.transport.TransportService;
 
 public class ADStatsNodesTransportActionTests extends OpenSearchIntegTestCase {
     private ADStatsNodesTransportAction action;
     private ADStats adStats;
-    private Map<String, ADStat<?>> statsMap;
+    private Map<String, TimeSeriesStat<?>> statsMap;
     private String clusterStatName1, clusterStatName2;
     private String nodeStatName1, nodeStatName2;
     private ADTaskManager adTaskManager;
@@ -76,8 +76,8 @@ public void setUp() throws Exception {
             clusterService(),
             indexNameResolver
         );
-        ModelManager modelManager = mock(ModelManager.class);
-        CacheProvider cacheProvider = mock(CacheProvider.class);
+        ADModelManager modelManager = mock(ADModelManager.class);
+        HCCacheProvider cacheProvider = mock(HCCacheProvider.class);
         EntityCache cache = mock(EntityCache.class);
         when(cacheProvider.get()).thenReturn(cache);
 
@@ -86,21 +86,24 @@ public void setUp() throws Exception {
         nodeStatName1 = "nodeStat1";
         nodeStatName2 = "nodeStat2";
 
-        Settings settings = Settings.builder().put(MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build();
+        Settings settings = Settings.builder().put(AD_MAX_MODEL_SIZE_PER_NODE.getKey(), 10).build();
         ClusterService clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(MAX_MODEL_SIZE_PER_NODE)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AD_MAX_MODEL_SIZE_PER_NODE)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
 
-        statsMap = new HashMap<String, ADStat<?>>() {
+        statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(nodeStatName1, new ADStat<>(false, new CounterSupplier()));
-                put(nodeStatName2, new ADStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService)));
-                put(clusterStatName1, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index1")));
-                put(clusterStatName2, new ADStat<>(true, new IndexStatusSupplier(indexUtils, "index2")));
-                put(InternalStatNames.JVM_HEAP_USAGE.getName(), new ADStat<>(true, new SettableSupplier()));
+                put(nodeStatName1, new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(
+                    nodeStatName2,
+                    new TimeSeriesStat<>(false, new ModelsOnNodeSupplier(modelManager, cacheProvider, settings, clusterService))
+                );
+                put(clusterStatName1, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index1")));
+                put(clusterStatName2, new TimeSeriesStat<>(true, new IndexStatusSupplier(indexUtils, "index2")));
+                put(InternalStatNames.JVM_HEAP_USAGE.getName(), new TimeSeriesStat<>(true, new SettableSupplier()));
             }
         };
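Aside: the statsMap rebuilt above is a registry of named stats whose values come from pluggable suppliers (CounterSupplier, IndexStatusSupplier, SettableSupplier), each flagged as node- or cluster-level. A rough stand-alone analogue of that design in plain Java, with invented class names:

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Supplier;

final class Stat<T> {
    private final boolean clusterLevel;   // cluster-level stats are only reported by one node
    private final Supplier<T> supplier;   // value is computed lazily on each read

    Stat(boolean clusterLevel, Supplier<T> supplier) {
        this.clusterLevel = clusterLevel;
        this.supplier = supplier;
    }

    boolean isClusterLevel() { return clusterLevel; }
    T getValue() { return supplier.get(); }

    public static void main(String[] args) {
        LongAdder counter = new LongAdder();
        Map<String, Stat<?>> stats = new HashMap<>();
        stats.put("nodeStat1", new Stat<>(false, counter::sum));     // counter-backed node stat
        stats.put("clusterStat1", new Stat<>(true, () -> "green")); // settable/status-backed cluster stat

        counter.increment();
        System.out.println(stats.get("nodeStat1").getValue());       // prints 1
    }
}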
diff --git a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java
index b9595e2e7..c8f23c506 100644
--- a/src/test/java/org/opensearch/ad/transport/ADStatsTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ADStatsTests.java
@@ -34,8 +34,6 @@
 import org.opensearch.Version;
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelState;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.Strings;
@@ -51,6 +49,7 @@
 
 import test.org.opensearch.ad.util.JsonDeserializer;
 
+import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
 import com.google.gson.JsonArray;
 import com.google.gson.JsonElement;
 
@@ -140,17 +139,18 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF
         attributes.put(name2, val2);
         String detectorId = "detectorId";
         Entity entity = Entity.createEntityFromOrderedMap(attributes);
-        EntityModel entityModel = new EntityModel(entity, null, null);
+        EntityModel<ThresholdedRandomCutForest> entityModel = new EntityModel<>(entity, null, null);
         Clock clock = mock(Clock.class);
         when(clock.instant()).thenReturn(Instant.now());
-        ModelState<EntityModel> state = new ModelState<>(
-            entityModel,
-            entity.getModelId(detectorId).get(),
-            detectorId,
-            "entity",
-            clock,
-            0.1f
-        );
+        ADModelState<EntityModel<ThresholdedRandomCutForest>> state =
+            new ADModelState<EntityModel<ThresholdedRandomCutForest>>(
+                entityModel,
+                entity.getModelId(detectorId).get(),
+                detectorId,
+                "entity",
+                clock,
+                0.1f
+            );
 
         Map<String, Object> stats = state.getModelStateAsMap();
 
         // Test serialization
@@ -167,7 +167,7 @@ public void testADStatsNodeResponseWithEntity() throws IOException, JsonPathNotF
         String json = Strings.toString(builder);
 
         for (Map.Entry<String, Object> stat : stats.entrySet()) {
-            if (stat.getKey().equals(ModelState.LAST_CHECKPOINT_TIME_KEY) || stat.getKey().equals(ModelState.LAST_USED_TIME_KEY)) {
+            if (stat.getKey().equals(ADModelState.LAST_CHECKPOINT_TIME_KEY) || stat.getKey().equals(ADModelState.LAST_USED_TIME_KEY)) {
                 assertEquals("toXContent does not work", JsonDeserializer.getLongValue(json, stat.getKey()), stat.getValue());
             } else if (stat.getKey().equals(CommonName.ENTITY_KEY)) {
                 JsonArray array = JsonDeserializer.getArrayValue(json, stat.getKey());
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
index 30931af13..c295558bd 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobActionTests.java
@@ -42,13 +42,15 @@
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.model.DateRange;
+import org.opensearch.timeseries.transport.JobRequest;
+import org.opensearch.timeseries.transport.JobResponse;
 import org.opensearch.transport.TransportService;
 
 public class AnomalyDetectorJobActionTests extends OpenSearchIntegTestCase {
     private AnomalyDetectorJobTransportAction action;
     private Task task;
-    private AnomalyDetectorJobRequest request;
-    private ActionListener<AnomalyDetectorJobResponse> response;
+    private JobRequest request;
+    private ActionListener<JobResponse> response;
 
     @Override
     @Before
@@ -57,7 +59,7 @@ public void setUp() throws Exception {
         ClusterService clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
         );
 
         Settings build = Settings.builder().build();
@@ -81,10 +83,10 @@ public void setUp() throws Exception {
             mock(ExecuteADResultResponseRecorder.class)
         );
         task = mock(Task.class);
-        request = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_start");
-        response = new ActionListener<AnomalyDetectorJobResponse>() {
+        request = new JobRequest("1234", 4567, 7890, "_start");
+        response = new ActionListener<JobResponse>() {
             @Override
-            public void onResponse(AnomalyDetectorJobResponse adResponse) {
+            public void onResponse(JobResponse adResponse) {
                 // Will not be called as there is no detector
                 Assert.assertTrue(false);
             }
@@ -104,7 +106,7 @@ public void testStartAdJobTransportAction() {
 
     @Test
     public void testStopAdJobTransportAction() {
-        AnomalyDetectorJobRequest stopRequest = new AnomalyDetectorJobRequest("1234", 4567, 7890, "_stop");
+        JobRequest stopRequest = new JobRequest("1234", 4567, 7890, "_stop");
         action.doExecute(task, stopRequest, response);
     }
 
@@ -117,13 +119,13 @@ public void testAdJobAction() {
     @Test
     public void testAdJobRequest() throws IOException {
         DateRange detectionDateRange = new DateRange(Instant.MIN, Instant.now());
-        request = new AnomalyDetectorJobRequest("1234", detectionDateRange, false, 4567, 7890, "_start");
+        request = new JobRequest("1234", detectionDateRange, false, 4567, 7890, "_start");
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        JobRequest newRequest = new JobRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
     }
 
     @Test
@@ -131,17 +133,17 @@ public void testAdJobRequest_NullDetectionDateRange() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
         request.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobRequest newRequest = new AnomalyDetectorJobRequest(input);
-        Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID());
+        JobRequest newRequest = new JobRequest(input);
+        Assert.assertEquals(request.getConfigID(), newRequest.getConfigID());
     }
 
     @Test
    public void testAdJobResponse() throws IOException {
         BytesStreamOutput out = new BytesStreamOutput();
-        AnomalyDetectorJobResponse response = new AnomalyDetectorJobResponse("1234", 45, 67, 890, RestStatus.OK);
+        JobResponse response = new JobResponse("1234", 45, 67, 890, RestStatus.OK);
         response.writeTo(out);
         StreamInput input = out.bytes().streamInput();
-        AnomalyDetectorJobResponse newResponse = new AnomalyDetectorJobResponse(input);
+        JobResponse newResponse = new JobResponse(input);
         Assert.assertEquals(response.getId(), newResponse.getId());
     }
 }
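Aside: the listener wired up in setUp above is expected never to complete normally (there is no detector), so onResponse fails the test outright while onFailure is the accepted path. The same guard pattern, reduced to plain Java with an invented Listener interface standing in for ActionListener:

interface Listener<T> {
    void onResponse(T response);
    void onFailure(Exception e);
}

public class ExpectFailureListenerSketch {
    // A listener that treats any successful callback as a test failure.
    static <T> Listener<T> expectFailure() {
        return new Listener<T>() {
            @Override public void onResponse(T response) {
                throw new AssertionError("expected onFailure, got response: " + response);
            }
            @Override public void onFailure(Exception e) {
                System.out.println("failed as expected: " + e.getMessage());
            }
        };
    }

    public static void main(String[] args) {
        Listener<String> listener = expectFailure();
        listener.onFailure(new IllegalStateException("no detector"));
    }
}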
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java
index daf86ab7c..acbe0ac0d 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyDetectorJobTransportActionTests.java
@@ -18,7 +18,6 @@
 import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
 import static org.opensearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
 import static org.opensearch.timeseries.TestHelpers.HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS;
-import static org.opensearch.timeseries.constant.CommonMessages.FAIL_TO_FIND_CONFIG_MSG;
 import static org.opensearch.timeseries.util.RestHandlerUtils.PROFILE;
 import static org.opensearch.timeseries.util.RestHandlerUtils.START_JOB;
 import static org.opensearch.timeseries.util.RestHandlerUtils.STOP_JOB;
@@ -46,10 +45,8 @@
 import org.opensearch.ad.mock.transport.MockAnomalyDetectorJobAction;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.ADTaskProfile;
-import org.opensearch.ad.model.ADTaskState;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.client.Client;
 import org.opensearch.common.lucene.uid.Versions;
 import org.opensearch.common.settings.Settings;
@@ -58,7 +55,11 @@
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.DateRange;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.transport.GetConfigRequest;
+import org.opensearch.timeseries.transport.JobRequest;
+import org.opensearch.timeseries.transport.JobResponse;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
@@ -97,7 +98,7 @@ protected Settings nodeSettings(int nodeOrdinal) {
     public void testDetectorIndexNotFound() {
         deleteDetectorIndex();
         String detectorId = randomAlphaOfLength(5);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
         IndexNotFoundException exception = expectThrows(
             IndexNotFoundException.class,
             () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(3000)
@@ -107,12 +108,12 @@ public void testDetectorIndexNotFound() {
 
     public void testDetectorNotFound() {
         String detectorId = randomAlphaOfLength(5);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
         OpenSearchStatusException exception = expectThrows(
             OpenSearchStatusException.class,
             () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000)
         );
-        assertTrue(exception.getMessage().contains(FAIL_TO_FIND_CONFIG_MSG));
+        assertTrue(exception.getMessage().contains(CommonMessages.FAIL_TO_FIND_CONFIG_MSG));
     }
 
     public void testValidHistoricalAnalysis() throws IOException, InterruptedException {
@@ -126,17 +127,10 @@ public void testStartHistoricalAnalysisWithUser() throws IOException {
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
             ADTask adTask = getADTask(response.getId());
             assertNotNull(adTask.getStartedBy());
             assertNotNull(adTask.getUser());
@@ -155,18 +149,11 @@ public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOEx
             ImmutableList.of(categoryField)
         );
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100000);
             waitUntil(() -> {
                 try {
                     ADTask task = getADTask(response.getId());
@@ -180,7 +167,7 @@ public void testStartHistoricalAnalysisForSingleCategoryHCWithUser() throws IOEx
             assertTrue(HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS.contains(adTask.getState()));
             assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0));
 
-            if (ADTaskState.FINISHED.name().equals(adTask.getState())) {
+            if (TaskState.FINISHED.name().equals(adTask.getState())) {
                 List<ADTask> adTasks = searchADTasks(detectorId, true, 100);
                 assertEquals(4, adTasks.size());
                 List<ADTask> entityTasks = adTasks
@@ -207,18 +194,11 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc
             ImmutableList.of(categoryField, ipField)
         );
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
         Client nodeClient = getDataNodeClient();
         if (nodeClient != null) {
-            AnomalyDetectorJobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100_000);
+            JobResponse response = nodeClient.execute(MockAnomalyDetectorJobAction.INSTANCE, request).actionGet(100_000);
             String taskId = response.getId();
 
             waitUntil(() -> {
@@ -236,7 +216,7 @@ public void testStartHistoricalAnalysisForMultiCategoryHCWithUser() throws IOExc
             assertEquals(categoryField, adTask.getDetector().getCategoryFields().get(0));
             assertEquals(ipField, adTask.getDetector().getCategoryFields().get(1));
 
-            if (ADTaskState.FINISHED.name().equals(adTask.getState())) {
+            if (TaskState.FINISHED.name().equals(adTask.getState())) {
                 List<ADTask> adTasks = searchADTasks(detectorId, taskId, true, 100);
                 assertEquals(5, adTasks.size());
                 List<ADTask> entityTasks = adTasks
@@ -252,8 +232,8 @@ public void testRunMultipleTasksForHistoricalAnalysis() throws IOException, Inte
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
-        AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
         assertNotNull(response.getId());
         OpenSearchStatusException exception = null;
         // Add retry to solve the flaky test
@@ -282,14 +262,7 @@ public void testRaceConditionByStartingMultipleTasks() throws IOException, Inter
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            UNASSIGNED_SEQ_NO,
-            UNASSIGNED_PRIMARY_TERM,
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
 
         client().execute(AnomalyDetectorJobAction.INSTANCE, request);
         client().execute(AnomalyDetectorJobAction.INSTANCE, request);
@@ -298,7 +271,7 @@ public void testRaceConditionByStartingMultipleTasks() throws IOException, Inter
         assertEquals(1, adTasks.size());
         assertTrue(adTasks.get(0).getLatest());
-        assertNotEquals(ADTaskState.FAILED.name(), adTasks.get(0).getState());
+        assertNotEquals(TaskState.FAILED.name(), adTasks.get(0).getState());
     }
 
     // TODO: fix this flaky test case
@@ -309,24 +282,17 @@ public void testCleanOldTaskDocs() throws InterruptedException, IOException {
         String detectorId = createDetector(detector);
         createDetectionStateIndex();
 
-        List<ADTaskState> states = ImmutableList.of(ADTaskState.FAILED, ADTaskState.FINISHED, ADTaskState.STOPPED);
-        for (ADTaskState state : states) {
+        List<TaskState> states = ImmutableList.of(TaskState.FAILED, TaskState.FINISHED, TaskState.STOPPED);
+        for (TaskState state : states) {
             ADTask task = randomADTask(randomAlphaOfLength(5), detector, detectorId, dateRange, state);
             createADTask(task);
         }
         long count = countDocs(ADCommonName.DETECTION_STATE_INDEX);
         assertEquals(states.size(), count);
 
-        AnomalyDetectorJobRequest request = new AnomalyDetectorJobRequest(
-            detectorId,
-            dateRange,
-            true,
-            randomLong(),
-            randomLong(),
-            START_JOB
-        );
+        JobRequest request = new JobRequest(detectorId, dateRange, true, randomLong(), randomLong(), START_JOB);
 
-        AtomicReference<AnomalyDetectorJobResponse> response = new AtomicReference<>();
+        AtomicReference<JobResponse> response = new AtomicReference<>();
         CountDownLatch latch = new CountDownLatch(1);
         Thread.sleep(2000);
         client().execute(AnomalyDetectorJobAction.INSTANCE, request, ActionListener.wrap(r -> {
@@ -354,13 +320,13 @@ public void testStartRealtimeDetector() throws IOException {
         String detectorId = realtimeResult.get(0);
         String jobId = realtimeResult.get(1);
         GetResponse jobDoc = getDoc(CommonName.JOB_INDEX, detectorId);
-        AnomalyDetectorJob job = toADJob(jobDoc);
+        Job job = toADJob(jobDoc);
         assertTrue(job.isEnabled());
         assertEquals(detectorId, job.getName());
 
         List<ADTask> adTasks = searchADTasks(detectorId, true, 10);
         assertEquals(1, adTasks.size());
-        assertEquals(ADTaskType.REALTIME_SINGLE_ENTITY.name(), adTasks.get(0).getTaskType());
+        assertEquals(ADTaskType.REALTIME_SINGLE_STREAM_DETECTOR.name(), adTasks.get(0).getTaskType());
         assertNotEquals(jobId, adTasks.get(0).getTaskId());
     }
 
@@ -368,8 +334,8 @@ private List<String> startRealtimeDetector() throws IOException {
         AnomalyDetector detector = TestHelpers
             .randomDetector(ImmutableList.of(maxValueFeature()), testIndex, detectionIntervalInMinutes, timeField);
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, null);
-        AnomalyDetectorJobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
+        JobRequest request = startDetectorJobRequest(detectorId, null);
+        JobResponse response = client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
         String jobId = response.getId();
         assertEquals(detectorId, jobId);
         return ImmutableList.of(detectorId, jobId);
@@ -399,7 +365,7 @@ public void testHistoricalDetectorWithoutEnabledFeature() throws IOException {
 
     private void testInvalidDetector(AnomalyDetector detector, String error) throws IOException {
         String detectorId = createDetector(detector);
-        AnomalyDetectorJobRequest request = startDetectorJobRequest(detectorId, dateRange);
+        JobRequest request = startDetectorJobRequest(detectorId, dateRange);
         OpenSearchStatusException exception = expectThrows(
             OpenSearchStatusException.class,
             () -> client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000)
@@ -407,12 +373,12 @@ private void testInvalidDetector(AnomalyDetector detector, String error) throws
         assertEquals(error, exception.getMessage());
     }
 
-    private AnomalyDetectorJobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) {
-        return new AnomalyDetectorJobRequest(detectorId, dateRange, false, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
+    private JobRequest startDetectorJobRequest(String detectorId, DateRange dateRange) {
+        return new JobRequest(detectorId, dateRange, false, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, START_JOB);
     }
 
-    private AnomalyDetectorJobRequest stopDetectorJobRequest(String detectorId, boolean historical) {
-        return new AnomalyDetectorJobRequest(detectorId, null, historical, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, STOP_JOB);
+    private JobRequest stopDetectorJobRequest(String detectorId, boolean historical) {
+        return new JobRequest(detectorId, null, historical, UNASSIGNED_SEQ_NO, UNASSIGNED_PRIMARY_TERM, STOP_JOB);
     }
 
     public void testStopRealtimeDetector() throws IOException {
@@ -420,24 +386,24 @@ public void testStopRealtimeDetector() throws IOException {
         String detectorId = realtimeResult.get(0);
         String jobId = realtimeResult.get(1);
 
-        AnomalyDetectorJobRequest request = stopDetectorJobRequest(detectorId, false);
+        JobRequest request = stopDetectorJobRequest(detectorId, false);
         client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
         GetResponse doc = getDoc(CommonName.JOB_INDEX, detectorId);
-        AnomalyDetectorJob job = toADJob(doc);
+        Job job = toADJob(doc);
         assertFalse(job.isEnabled());
         assertEquals(detectorId, job.getName());
 
         List<ADTask> adTasks = searchADTasks(detectorId, true, 10);
         assertEquals(1, adTasks.size());
-        assertEquals(ADTaskType.REALTIME_SINGLE_ENTITY.name(), adTasks.get(0).getTaskType());
+        assertEquals(ADTaskType.REALTIME_SINGLE_STREAM_DETECTOR.name(), adTasks.get(0).getTaskType());
         assertNotEquals(jobId, adTasks.get(0).getTaskId());
-        assertEquals(ADTaskState.STOPPED.name(), adTasks.get(0).getState());
+        assertEquals(TaskState.STOPPED.name(), adTasks.get(0).getState());
     }
 
     public void testStopHistoricalDetector() throws IOException, InterruptedException {
         updateTransientSettings(ImmutableMap.of(BATCH_TASK_PIECE_INTERVAL_SECONDS.getKey(), 5));
         ADTask adTask = startHistoricalAnalysis(startTime, endTime);
-        assertEquals(ADTaskState.INIT.name(), adTask.getState());
+        assertEquals(TaskState.INIT.name(), adTask.getState());
         assertNull(adTask.getStartedBy());
         assertNull(adTask.getUser());
         waitUntil(() -> {
@@ -447,7 +413,7 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio
                 if (taskRunning) {
                     // It's possible that the task not started on worker node yet. Recancel it to make sure
                     // task cancelled.
-                    AnomalyDetectorJobRequest request = stopDetectorJobRequest(adTask.getId(), true);
+                    JobRequest request = stopDetectorJobRequest(adTask.getId(), true);
                     client().execute(AnomalyDetectorJobAction.INSTANCE, request).actionGet(10000);
                 }
                 return !taskRunning;
@@ -456,13 +422,13 @@ public void testStopHistoricalDetector() throws IOException, InterruptedExceptio
             }
         }, 20, TimeUnit.SECONDS);
         ADTask stoppedTask = getADTask(adTask.getTaskId());
-        assertEquals(ADTaskState.STOPPED.name(), stoppedTask.getState());
+        assertEquals(TaskState.STOPPED.name(), stoppedTask.getState());
         assertEquals(0, getExecutingADTask());
     }
 
     public void testProfileHistoricalDetector() throws IOException, InterruptedException {
         ADTask adTask = startHistoricalAnalysis(startTime, endTime);
-        GetAnomalyDetectorRequest request = taskProfileRequest(adTask.getId());
+        GetConfigRequest request = taskProfileRequest(adTask.getId());
         GetAnomalyDetectorResponse response = client().execute(GetAnomalyDetectorAction.INSTANCE, request).actionGet(10000);
         assertTrue(response.getDetectorProfile().getAdTaskProfile() != null);
 
@@ -488,8 +454,8 @@ public void testProfileWithMultipleRunningTask() throws IOException {
         ADTask adTask1 = startHistoricalAnalysis(startTime, endTime);
         ADTask adTask2 = startHistoricalAnalysis(startTime, endTime);
 
-        GetAnomalyDetectorRequest request1 = taskProfileRequest(adTask1.getId());
-        GetAnomalyDetectorRequest request2 = taskProfileRequest(adTask2.getId());
+        GetConfigRequest request1 = taskProfileRequest(adTask1.getId());
+        GetConfigRequest request2 = taskProfileRequest(adTask2.getId());
         GetAnomalyDetectorResponse response1 = client().execute(GetAnomalyDetectorAction.INSTANCE, request1).actionGet(10000);
         GetAnomalyDetectorResponse response2 = client().execute(GetAnomalyDetectorAction.INSTANCE, request2).actionGet(10000);
         ADTaskProfile taskProfile1 = response1.getDetectorProfile().getAdTaskProfile();
@@ -499,8 +465,8 @@ public void testProfileWithMultipleRunningTask() throws IOException {
         assertNotEquals(taskProfile1.getNodeId(), taskProfile2.getNodeId());
     }
 
-    private GetAnomalyDetectorRequest taskProfileRequest(String detectorId) throws IOException {
-        return new GetAnomalyDetectorRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null);
+    private GetConfigRequest taskProfileRequest(String detectorId) throws IOException {
+        return new GetConfigRequest(detectorId, Versions.MATCH_ANY, false, false, "", PROFILE, true, null);
     }
 
     private long getExecutingADTask() {
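Aside: these tests lean heavily on waitUntil(...) to poll until an asynchronous task reaches a final state (STOPPED/FINISHED/FAILED) or a timeout elapses. The real helper comes from the OpenSearch test framework; this stand-alone version only illustrates the control flow:

import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;

final class WaitUntilSketch {
    // Retry the condition until it holds or the deadline passes; returns the final result.
    static boolean waitUntil(BooleanSupplier condition, long timeout, TimeUnit unit) throws InterruptedException {
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        while (System.nanoTime() < deadline) {
            if (condition.getAsBoolean()) {
                return true;       // e.g. the AD task is no longer running
            }
            Thread.sleep(100);     // poll interval
        }
        return condition.getAsBoolean();
    }

    public static void main(String[] args) throws InterruptedException {
        long start = System.currentTimeMillis();
        boolean done = waitUntil(() -> System.currentTimeMillis() - start > 300, 20, TimeUnit.SECONDS);
        System.out.println(done); // true well before the 20s timeout
    }
}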
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java
index 6b671d6e2..02219c398 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTests.java
@@ -65,25 +65,14 @@
 import org.opensearch.action.index.IndexResponse;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
-import org.opensearch.ad.NodeStateManager;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
-import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
-import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.feature.SinglePointFeatures;
-import org.opensearch.ad.ml.ModelManager;
-import org.opensearch.ad.ml.SingleStreamModelIdMapper;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.DetectorInternalState;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
-import org.opensearch.ad.stats.ADStat;
-import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
 import org.opensearch.ad.task.ADTaskManager;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.ClusterState;
@@ -108,6 +97,8 @@
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.breaker.CircuitBreakerService;
+import org.opensearch.timeseries.cluster.HashRing;
 import org.opensearch.timeseries.common.exception.EndRunException;
 import org.opensearch.timeseries.common.exception.InternalFailure;
 import org.opensearch.timeseries.common.exception.LimitExceededException;
@@ -115,8 +106,16 @@
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
 import org.opensearch.timeseries.constant.CommonMessages;
 import org.opensearch.timeseries.constant.CommonName;
+import org.opensearch.timeseries.feature.FeatureManager;
+import org.opensearch.timeseries.feature.SinglePointFeatures;
+import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
 import org.opensearch.timeseries.model.FeatureData;
 import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.stats.Stats;
+import org.opensearch.timeseries.stats.TimeSeriesStat;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
+import org.opensearch.timeseries.transport.ResultResponse;
+import org.opensearch.timeseries.util.SecurityClientUtil;
 import org.opensearch.transport.NodeNotConnectedException;
 import org.opensearch.transport.RemoteTransportException;
 import org.opensearch.transport.Transport;
@@ -136,9 +135,9 @@ public class AnomalyResultTests extends AbstractTimeSeriesTest {
     private Settings settings;
     private TransportService transportService;
     private ClusterService clusterService;
-    private NodeStateManager stateManager;
+    private ADNodeStateManager stateManager;
     private FeatureManager featureQuery;
-    private ModelManager normalModelManager;
+    private ADModelManager normalModelManager;
     private Client client;
     private SecurityClientUtil clientUtil;
     private AnomalyDetector detector;
@@ -148,8 +147,8 @@ public class AnomalyResultTests extends AbstractTimeSeriesTest {
     private String adID;
     private String featureId;
     private String featureName;
-    private ADCircuitBreakerService adCircuitBreakerService;
-    private ADStats adStats;
+    private CircuitBreakerService adCircuitBreakerService;
+    private Stats adStats;
     private double confidence;
     private double anomalyGrade;
     private ADTaskManager adTaskManager;
@@ -171,13 +170,13 @@ public void setUp() throws Exception {
         super.setUp();
         super.setUpLog4jForJUnit(AnomalyResultTransportAction.class);
 
-        setupTestNodes(AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.PAGE_SIZE);
+        setupTestNodes(AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.AD_PAGE_SIZE);
 
         transportService = testNodes[0].transportService;
         clusterService = testNodes[0].clusterService;
         settings = clusterService.getSettings();
 
-        stateManager = mock(NodeStateManager.class);
+        stateManager = mock(ADNodeStateManager.class);
         when(stateManager.isMuted(any(String.class), any(String.class))).thenReturn(false);
         when(stateManager.markColdStartRunning(anyString())).thenReturn(() -> {});
 
@@ -197,12 +196,12 @@ public void setUp() throws Exception {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(stateManager).getConfig(any(String.class), any(ActionListener.class));
         when(detector.getIntervalInMinutes()).thenReturn(1L);
 
         hashRing = mock(HashRing.class);
         Optional<DiscoveryNode> localNode = Optional.of(clusterService.state().nodes().getLocalNode());
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode);
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode);
         doReturn(localNode).when(hashRing).getNodeByAddress(any());
         featureQuery = mock(FeatureManager.class);
@@ -215,7 +214,7 @@ public void setUp() throws Exception {
         double rcfScore = 0.2;
         confidence = 0.91;
         anomalyGrade = 0.5;
-        normalModelManager = mock(ModelManager.class);
+        normalModelManager = mock(ADModelManager.class);
         long totalUpdates = 1440;
         int relativeIndex = 0;
         double[] currentTimeAttribution = new double[] { 0.5, 0.5 };
@@ -252,7 +251,7 @@ public void setUp() throws Exception {
         thresholdModelID = SingleStreamModelIdMapper.getThresholdModelId(adID); // "123-threshold";
         // when(normalModelPartitioner.getThresholdModelId(any(String.class))).thenReturn(thresholdModelID);
-        adCircuitBreakerService = mock(ADCircuitBreakerService.class);
+        adCircuitBreakerService = mock(CircuitBreakerService.class);
         when(adCircuitBreakerService.isOpen()).thenReturn(false);
 
         ThreadPool threadPool = mock(ThreadPool.class);
@@ -282,21 +281,21 @@ public void setUp() throws Exception {
             return null;
         }).when(client).index(any(), any());
 
-        NodeStateManager nodeStateManager = mock(NodeStateManager.class);
+        ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class);
         clientUtil = new SecurityClientUtil(nodeStateManager, settings);
 
         indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY));
 
-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };
 
-        adStats = new ADStats(statsMap);
+        adStats = new Stats(statsMap);
 
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
@@ -308,7 +307,6 @@ public void setUp() throws Exception {
                 DetectorInternalState.Builder result = new DetectorInternalState.Builder().lastUpdateTime(Instant.now());
 
                 listener.onResponse(TestHelpers.createGetResponse(result.build(), detector.getId(), ADCommonName.DETECTION_STATE_INDEX));
-
             }
 
             return null;
@@ -321,7 +319,7 @@ public void setUp() throws Exception {
             return null;
         })
             .when(adTaskManager)
-            .initRealtimeTaskCacheAndCleanupStaleCache(
+            .initCacheWithCleanupIfRequired(
                 anyString(),
                 any(AnomalyDetector.class),
                 any(TransportService.class),
@@ -459,13 +457,13 @@ public void sendRequest(
         setupTestNodes(
             failureTransportInterceptor,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY,
-            AnomalyDetectorSettings.PAGE_SIZE
+            AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY,
+            AnomalyDetectorSettings.AD_PAGE_SIZE
         );
 
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
         Optional<DiscoveryNode> discoveryNode = Optional.of(testNodes[1].discoveryNode());
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode);
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode);
         when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode);
         // register handler on testNodes[1]
         new RCFResultTransportAction(
@@ -514,7 +512,7 @@ public void noModelExceptionTemplate(Exception exception, String adID, String er
 
     @SuppressWarnings("unchecked")
     public void testInsufficientCapacityExceptionDuringColdStart() {
-        ModelManager rcfManager = mock(ModelManager.class);
+        ADModelManager rcfManager = mock(ADModelManager.class);
         doThrow(ResourceNotFoundException.class)
             .when(rcfManager)
             .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
@@ -562,7 +560,7 @@
 
     @SuppressWarnings("unchecked")
     public void testInsufficientCapacityExceptionDuringRestoringModel() {
-        ModelManager rcfManager = mock(ModelManager.class);
+        ADModelManager rcfManager = mock(ADModelManager.class);
         doThrow(new NotSerializableExceptionWrapper(new LimitExceededException(adID, CommonMessages.MEMORY_LIMIT_EXCEEDED_ERR_MSG)))
             .when(rcfManager)
             .getTRcfResult(any(String.class), any(String.class), any(double[].class), any(ActionListener.class));
@@ -682,13 +680,13 @@ public void sendRequest(
         setupTestNodes(
             failureTransportInterceptor,
             Settings.EMPTY,
-            AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY,
-            AnomalyDetectorSettings.PAGE_SIZE
+            AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY,
+            AnomalyDetectorSettings.AD_PAGE_SIZE
         );
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
         Optional<DiscoveryNode> discoveryNode = Optional.of(testNodes[1].discoveryNode());
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(discoveryNode);
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(discoveryNode);
         when(hashRing.getNodeByAddress(any(TransportAddress.class))).thenReturn(discoveryNode);
         // register handlers on testNodes[1]
         ActionFilters actionFilters = new ActionFilters(Collections.emptySet());
@@ -734,7 +732,7 @@ public void sendRequest(
 
     public void testCircuitBreaker() {
 
-        ADCircuitBreakerService breakerService = mock(ADCircuitBreakerService.class);
+        TimeSeriesCircuitBreakerService breakerService = mock(TimeSeriesCircuitBreakerService.class);
         when(breakerService.isOpen()).thenReturn(true);
 
         // These constructors register handler in transport service
@@ -797,7 +795,7 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary,
                 .when(exceptionTransportService)
                 .getConnection(same(rcfNode));
         } else {
-            when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode));
+            when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(thresholdModelID))).thenReturn(Optional.of(thresholdNode));
             when(hashRing.getNodeByAddress(any())).thenReturn(Optional.of(thresholdNode));
             doThrow(new NodeNotConnectedException(rcfNode, "rcf node not connected"))
                 .when(exceptionTransportService)
@@ -844,10 +842,10 @@ private void nodeNotConnectedExceptionTemplate(boolean isRCF, boolean temporary,
         assertException(listener, TimeSeriesException.class);
 
         if (!temporary) {
-            verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtimeAD();
+            verify(hashRing, times(numberOfBuildCall)).buildCirclesForRealtime();
             verify(stateManager, never()).addPressure(any(String.class), any(String.class));
         } else {
-            verify(hashRing, never()).buildCirclesForRealtimeAD();
+            verify(hashRing, never()).buildCirclesForRealtime();
             verify(stateManager, times(numberOfBuildCall)).addPressure(any(String.class), any(String.class));
         }
     }
@@ -864,13 +862,13 @@ public void testTemporaryRCFNodeNotConnectedException() {
 
     @SuppressWarnings("unchecked")
     public void testMute() {
-        NodeStateManager muteStateManager = mock(NodeStateManager.class);
+        ADNodeStateManager muteStateManager = mock(ADNodeStateManager.class);
         when(muteStateManager.isMuted(any(String.class), any(String.class))).thenReturn(true);
         doAnswer(invocation -> {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(muteStateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(muteStateManager).getConfig(any(String.class), any(ActionListener.class));
         AnomalyResultTransportAction action = new AnomalyResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             transportService,
@@ -894,7 +892,7 @@ public void testMute() {
         action.doExecute(null, request, listener);
 
         Throwable exception = assertException(listener, TimeSeriesException.class);
-        assertThat(exception.getMessage(), containsString(AnomalyResultTransportAction.NODE_UNRESPONSIVE_ERR_MSG));
+        assertThat(exception.getMessage(), containsString(TimeSeriesResultProcessor.NODE_UNRESPONSIVE_ERR_MSG));
     }
 
     public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE
@@ -909,7 +907,7 @@ public void alertingRequestTemplate(boolean anomalyResultIndexExists) throws IOE
         );
         Optional<DiscoveryNode> localNode = Optional.of(clusterService.state().nodes().getLocalNode());
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(localNode);
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(localNode);
         doReturn(localNode).when(hashRing).getNodeByAddress(any());
 
         new ThresholdResultTransportAction(new ActionFilters(Collections.emptySet()), transportService, normalModelManager);
@@ -970,7 +968,7 @@ public String executor() {
     }
 
     public void testSerialzationResponse() throws IOException {
-        AnomalyResultResponse response = new AnomalyResultResponse(
+        ResultResponse response = new AnomalyResultResponse(
             4d,
             0.993,
             1.01,
@@ -995,7 +993,7 @@ public void testSerialzationResponse() throws IOException {
     }
 
     public void testJsonResponse() throws IOException, JsonPathNotFoundException {
-        AnomalyResultResponse response = new AnomalyResultResponse(
+        ResultResponse response = new AnomalyResultResponse(
             4d,
             0.993,
             1.01,
@@ -1053,7 +1051,7 @@ public void testSerialzationRequest() throws IOException {
         StreamInput streamInput = output.bytes().streamInput();
         AnomalyResultRequest readRequest = new AnomalyResultRequest(streamInput);
-        assertThat(request.getAdID(), equalTo(readRequest.getAdID()));
+        assertThat(request.getConfigID(), equalTo(readRequest.getConfigID()));
         assertThat(request.getStart(), equalTo(readRequest.getStart()));
         assertThat(request.getEnd(), equalTo(readRequest.getEnd()));
     }
@@ -1064,7 +1062,7 @@ public void testJsonRequest() throws IOException, JsonPathNotFoundException {
         request.toXContent(builder, ToXContent.EMPTY_PARAMS);
         String json = Strings.toString(builder);
-        assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getAdID());
+        assertEquals(JsonDeserializer.getTextValue(json, ADCommonName.ID_JSON_KEY), request.getConfigID());
         assertEquals(JsonDeserializer.getLongValue(json, CommonName.START_JSON_KEY), request.getStart());
         assertEquals(JsonDeserializer.getLongValue(json, CommonName.END_JSON_KEY), request.getEnd());
     }
@@ -1481,17 +1479,17 @@ private void globalBlockTemplate(BlockType type, String errLogMsg) {
     }
 
     public void testReadBlock() {
-        globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, AnomalyResultTransportAction.READ_WRITE_BLOCKED);
+        globalBlockTemplate(BlockType.GLOBAL_BLOCK_READ, TimeSeriesResultProcessor.READ_WRITE_BLOCKED);
     }
 
     public void testWriteBlock() {
-        globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, AnomalyResultTransportAction.READ_WRITE_BLOCKED);
+        globalBlockTemplate(BlockType.GLOBAL_BLOCK_WRITE, TimeSeriesResultProcessor.READ_WRITE_BLOCKED);
     }
 
     public void testIndexReadBlock() {
         globalBlockTemplate(
             BlockType.INDEX_BLOCK,
-            AnomalyResultTransportAction.INDEX_READ_BLOCKED,
+            TimeSeriesResultProcessor.INDEX_READ_BLOCKED,
             Settings.builder().put(IndexMetadata.INDEX_BLOCKS_READ_SETTING.getKey(), true).build(),
             "test1"
         );
@@ -1521,7 +1519,7 @@ public void testNullRCFResult() {
             "123-rcf-0", null, "123", null, mock(ActionListener.class), null, null
         );
         listener.onResponse(null);
-        assertTrue(testAppender.containsMessage(AnomalyResultTransportAction.NULL_RESPONSE));
+        assertTrue(testAppender.containsMessage(TimeSeriesResultProcessor.NULL_RESPONSE));
     }
 
     @SuppressWarnings("unchecked")
@@ -1600,7 +1598,7 @@ public void testAllFeaturesDisabled() throws IOException {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onFailure(new EndRunException(adID, CommonMessages.ALL_FEATURES_DISABLED_ERR_MSG, true));
             return null;
-        }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class));
+        }).when(stateManager).getConfig(any(String.class), any(ActionListener.class));
 
         AnomalyResultTransportAction action = new AnomalyResultTransportAction(
             new ActionFilters(Collections.emptySet()),
@@ -1633,7 +1631,7 @@ public void testEndRunDueToNoTrainingData() {
         ThreadPool mockThreadPool = mock(ThreadPool.class);
         setUpColdStart(mockThreadPool, new ColdStartConfig.Builder().coldStartRunning(false).build());
 
-        ModelManager rcfManager = mock(ModelManager.class);
+        ADModelManager rcfManager = mock(ADModelManager.class);
         doAnswer(invocation -> {
             Object[] args = invocation.getArguments();
             ActionListener<ThresholdingResult> listener = (ActionListener<ThresholdingResult>) args[3];
diff --git a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
index 0cd9218f0..78ffca8dd 100644
--- a/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/AnomalyResultTransportActionTests.java
@@ -25,7 +25,6 @@
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.ad.ADIntegTestCase;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.util.ExceptionUtil;
 import org.opensearch.core.common.io.stream.NotSerializableExceptionWrapper;
 import org.opensearch.search.aggregations.AggregationBuilder;
 import org.opensearch.test.rest.OpenSearchRestTestCase;
@@ -33,6 +32,7 @@
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.util.ExceptionUtil;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
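Aside: the doAnswer(...) blocks threaded through these tests all stub the same async idiom: grab the ActionListener argument out of the mocked call and complete it inline, so callback-driven code runs synchronously under test. A minimal Mockito sketch of the idiom with an invented ConfigStore interface (a stand-in for the mocked manager classes):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.util.Optional;
import java.util.function.Consumer;

public class AsyncStubSketch {
    interface ConfigStore {
        void getConfig(String id, Consumer<Optional<String>> listener);
    }

    public static void main(String[] args) {
        ConfigStore store = mock(ConfigStore.class);
        doAnswer(invocation -> {
            Consumer<Optional<String>> listener = invocation.getArgument(1); // the callback argument
            listener.accept(Optional.of("detector-config"));                 // complete it inline
            return null;
        }).when(store).getConfig(anyString(), any());

        store.getConfig("1234", cfg -> System.out.println("got: " + cfg.orElse("none")));
    }
}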
diff --git a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java
index a65c35839..bbd4251f0 100644
--- a/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/CronTransportActionTests.java
@@ -23,13 +23,10 @@
 import org.junit.Before;
 import org.opensearch.Version;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.NodeStateManager;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
+import org.opensearch.ad.ADNodeStateManager;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.ml.EntityColdStarter;
-import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.ml.ADEntityColdStart;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.cluster.node.DiscoveryNode;
@@ -42,6 +39,9 @@
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
+import org.opensearch.timeseries.caching.EntityCache;
+import org.opensearch.timeseries.caching.HCCacheProvider;
+import org.opensearch.timeseries.feature.FeatureManager;
 import org.opensearch.transport.TransportService;
 
 import test.org.opensearch.ad.util.JsonDeserializer;
@@ -65,12 +65,12 @@ public void setUp() throws Exception {
         TransportService transportService = mock(TransportService.class);
         ActionFilters actionFilters = mock(ActionFilters.class);
-        NodeStateManager tarnsportStatemanager = mock(NodeStateManager.class);
-        ModelManager modelManager = mock(ModelManager.class);
+        ADNodeStateManager tarnsportStatemanager = mock(ADNodeStateManager.class);
+        ADModelManager modelManager = mock(ADModelManager.class);
         FeatureManager featureManager = mock(FeatureManager.class);
-        CacheProvider cacheProvider = mock(CacheProvider.class);
+        HCCacheProvider cacheProvider = mock(HCCacheProvider.class);
         EntityCache entityCache = mock(EntityCache.class);
-        EntityColdStarter entityColdStarter = mock(EntityColdStarter.class);
+        ADEntityColdStart entityColdStarter = mock(ADEntityColdStart.class);
         when(cacheProvider.get()).thenReturn(entityCache);
         ADTaskManager adTaskManager = mock(ADTaskManager.class);
diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java
index ca7fae8ba..02fa3ea45 100644
--- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorActionTests.java
@@ -49,7 +49,7 @@ public void setUp() throws Exception {
         ClusterService clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
         adTaskManager = mock(ADTaskManager.class);
diff --git a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java
index 7b67843f1..02d260d47 100644
--- a/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java
+++ b/src/test/java/org/opensearch/ad/transport/DeleteAnomalyDetectorTests.java
@@ -21,7 +21,6 @@
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Map;
 import java.util.Optional;
 import java.util.function.Consumer;
 
@@ -37,7 +36,6 @@
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.ad.model.ADTask;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.ad.task.ADTaskManager;
 import org.opensearch.client.Client;
@@ -58,6 +56,7 @@
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.function.ExecutorFunction;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.transport.Transport;
 import org.opensearch.transport.TransportService;
 
@@ -71,7 +70,7 @@ public class DeleteAnomalyDetectorTests extends AbstractTimeSeriesTest {
     private DeleteResponse deleteResponse;
     private GetResponse getResponse;
     ClusterService clusterService;
-    private AnomalyDetectorJob jobParameter;
+    private Job jobParameter;
 
     @BeforeClass
     public static void setUpBeforeClass() {
@@ -89,7 +88,7 @@ public void setUp() throws Exception {
         clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
         transportService = new TransportService(
@@ -117,7 +116,7 @@ public void setUp() throws Exception {
             adTaskManager
         );
 
-        jobParameter = mock(AnomalyDetectorJob.class);
+        jobParameter = mock(Job.class);
         when(jobParameter.getName()).thenReturn(randomAlphaOfLength(10));
         IntervalSchedule schedule = new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES);
         when(jobParameter.getSchedule()).thenReturn(schedule);
@@ -165,7 +164,7 @@ public void testDeleteADTransportAction_LatestDetectorLevelTask() {
             ADTask adTask = ADTask.builder().state("RUNNING").build();
             consumer.accept(Optional.of(adTask));
             return null;
-        }).when(adTaskManager).getAndExecuteOnLatestDetectorLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any());
+        }).when(adTaskManager).getAndExecuteOnLatestConfigLevelTask(eq("1234"), any(), any(), eq(transportService), eq(true), any());
 
         future = mock(PlainActionFuture.class);
         DeleteAnomalyDetectorRequest request = new DeleteAnomalyDetectorRequest("1234");
@@ -197,9 +196,9 @@ public void testDeleteADTransportAction_GetResponseException() {
     }
 
     private ClusterState createClusterState() {
-        Map<String, IndexMetadata> immutableOpenMap = new HashMap<>();
-        immutableOpenMap
-            .put(
+        ImmutableOpenMap<String, IndexMetadata> immutableOpenMap = ImmutableOpenMap
+            .builder()
+            .fPut(
                 CommonName.JOB_INDEX,
                 IndexMetadata
                     .builder("test")
@@ -288,7 +287,7 @@ private void setupMocks(
                         true,
                         BytesReference
                             .bytes(
-                                new AnomalyDetectorJob(
+                                new Job(
                                     "1234",
                                     jobParameter.getSchedule(),
                                     jobParameter.getWindowDelay(),
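Aside: the createClusterState change above swaps a mutable HashMap for OpenSearch's ImmutableOpenMap, built fluently via builder()/fPut() and then frozen. A plain-Java analogue of that build-then-freeze pattern, with an invented builder class (ImmutableOpenMap itself lives in OpenSearch core):

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

final class FrozenMapBuilder<K, V> {
    private final Map<K, V> entries = new HashMap<>();

    // fluent put, like ImmutableOpenMap.Builder#fPut
    FrozenMapBuilder<K, V> fPut(K key, V value) {
        entries.put(key, value);
        return this;
    }

    // freeze into an unmodifiable snapshot
    Map<K, V> build() {
        return Collections.unmodifiableMap(new HashMap<>(entries));
    }

    public static void main(String[] args) {
        Map<String, String> indices = new FrozenMapBuilder<String, String>()
            .fPut("job-index", "test")
            .build();
        System.out.println(indices); // {job-index=test}
    }
}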
public void testNormalDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest("123"); - ActionFuture<DeleteModelResponse> future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture<DeleteModelResponse> future = client().execute(DeleteADModelAction.INSTANCE, request); DeleteModelResponse response = future.get(); assertTrue(!response.hasFailures()); @@ -53,15 +59,15 @@ public void testNormalDeleteModel() throws ExecutionException, InterruptedExcept public void testEmptyIDDeleteModel() throws ExecutionException, InterruptedException { DeleteModelRequest request = new DeleteModelRequest(""); - ActionFuture<DeleteModelResponse> future = client().execute(DeleteModelAction.INSTANCE, request); + ActionFuture<DeleteModelResponse> future = client().execute(DeleteADModelAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } public void testEmptyIDStopDetector() throws ExecutionException, InterruptedException { - StopDetectorRequest request = new StopDetectorRequest(); + StopConfigRequest request = new StopConfigRequest(); - ActionFuture<StopDetectorResponse> future = client().execute(StopDetectorAction.INSTANCE, request); + ActionFuture<StopConfigResponse> future = client().execute(StopDetectorAction.INSTANCE, request); expectThrows(ActionRequestValidationException.class, () -> future.actionGet()); } diff --git a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java index fd74a2802..ddb4cda3a 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteModelTransportActionTests.java @@ -27,14 +27,11 @@ import org.opensearch.Version; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; +import org.opensearch.ad.ADNodeStateManager; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; -import org.opensearch.ad.feature.FeatureManager; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -47,6 +44,13 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.caching.EntityCache; +import org.opensearch.timeseries.caching.HCCacheProvider; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.DeleteModelNodeRequest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; import org.opensearch.transport.TransportService; import test.org.opensearch.ad.util.JsonDeserializer; @@ -54,7 +58,7 @@ import com.google.gson.JsonElement; public class DeleteModelTransportActionTests extends AbstractTimeSeriesTest { - private DeleteModelTransportAction action; + private DeleteADModelTransportAction action; private String localNodeID; @Override @@ -70,16 +74,16 @@ public void setUp() throws Exception { TransportService transportService = mock(TransportService.class); ActionFilters actionFilters = mock(ActionFilters.class); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); - ModelManager modelManager = mock(ModelManager.class); + ADNodeStateManager nodeStateManager =
mock(ADNodeStateManager.class); + ADModelManager modelManager = mock(ADModelManager.class); FeatureManager featureManager = mock(FeatureManager.class); - CacheProvider cacheProvider = mock(CacheProvider.class); + HCCacheProvider cacheProvider = mock(HCCacheProvider.class); EntityCache entityCache = mock(EntityCache.class); when(cacheProvider.get()).thenReturn(entityCache); ADTaskCacheManager adTaskCacheManager = mock(ADTaskCacheManager.class); - EntityColdStarter coldStarter = mock(EntityColdStarter.class); + ADEntityColdStart coldStarter = mock(ADEntityColdStart.class); - action = new DeleteModelTransportAction( + action = new DeleteADModelTransportAction( threadPool, clusterService, transportService, diff --git a/src/test/java/org/opensearch/ad/transport/DeleteTests.java b/src/test/java/org/opensearch/ad/transport/DeleteTests.java index 619ee6bb2..42bdb0a41 100644 --- a/src/test/java/org/opensearch/ad/transport/DeleteTests.java +++ b/src/test/java/org/opensearch/ad/transport/DeleteTests.java @@ -59,6 +59,11 @@ import org.opensearch.tasks.Task; import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.AbstractTimeSeriesTest; +import org.opensearch.timeseries.transport.DeleteModelNodeResponse; +import org.opensearch.timeseries.transport.DeleteModelRequest; +import org.opensearch.timeseries.transport.DeleteModelResponse; +import org.opensearch.timeseries.transport.StopConfigRequest; +import org.opensearch.timeseries.transport.StopConfigResponse; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.TransportService; @@ -140,7 +145,7 @@ public void testSerialzationResponse() throws IOException { response.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - DeleteModelResponse readResponse = DeleteModelAction.INSTANCE.getResponseReader().read(streamInput); + DeleteModelResponse readResponse = DeleteADModelAction.INSTANCE.getResponseReader().read(streamInput); assertTrue(readResponse.hasFailures()); assertEquals(failures.size(), readResponse.failures().size()); @@ -153,12 +158,12 @@ public void testEmptyIDDeleteModel() { } public void testEmptyIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().validate(); + ActionRequestValidationException e = new StopConfigRequest().validate(); assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); } public void testValidIDStopDetector() { - ActionRequestValidationException e = new StopDetectorRequest().adID("foo").validate(); + ActionRequestValidationException e = new StopConfigRequest().adID("foo").validate(); assertThat(e, is(nullValue())); } @@ -172,12 +177,12 @@ public void testSerialzationRequestDeleteModel() throws IOException { } public void testSerialzationRequestStopDetector() throws IOException { - StopDetectorRequest request = new StopDetectorRequest().adID("123"); + StopConfigRequest request = new StopConfigRequest().adID("123"); BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); StreamInput streamInput = output.bytes().streamInput(); - StopDetectorRequest readRequest = new StopDetectorRequest(streamInput); - assertThat(request.getAdID(), equalTo(readRequest.getAdID())); + StopConfigRequest readRequest = new StopConfigRequest(streamInput); + assertThat(request.getConfigID(), equalTo(readRequest.getConfigID())); } public void testJsonRequestTemplate(R request, Supplier requestSupplier) throws IOException, @@ -190,8 +195,8 @@ public void testJsonRequestTemplate(R request, Supplier 
- StopDetectorRequest request = new StopDetectorRequest().adID(detectorID); - PlainActionFuture<StopDetectorResponse> listener = new PlainActionFuture<>(); + StopConfigRequest request = new StopConfigRequest().adID(detectorID); + PlainActionFuture<StopConfigResponse> listener = new PlainActionFuture<>(); action.doExecute(task, request, listener); - StopDetectorResponse response = listener.actionGet(); + StopConfigResponse response = listener.actionGet(); assertTrue(!response.success()); } diff --git a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java index 4a1cfc718..397411322 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityProfileTests.java @@ -31,14 +31,10 @@ import org.opensearch.Version; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import org.opensearch.ad.cluster.HashRing; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; import org.opensearch.ad.model.EntityProfileName; -import org.opensearch.ad.model.ModelProfile; import org.opensearch.ad.model.ModelProfileOnNode; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; @@ -49,9 +45,12 @@ import org.opensearch.tasks.Task; import org.opensearch.timeseries.AbstractTimeSeriesTest; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.caching.EntityCache; +import org.opensearch.timeseries.cluster.HashRing; import org.opensearch.timeseries.common.exception.TimeSeriesException; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.ModelProfile; import org.opensearch.transport.ConnectTransportException; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; @@ -77,7 +76,7 @@ public class EntityProfileTests extends AbstractTimeSeriesTest { private TransportService transportService; private Settings settings; private ClusterService clusterService; - private CacheProvider cacheProvider; + private EntityCacheProvider cacheProvider; private EntityProfileTransportAction action; private Task task; private PlainActionFuture<EntityProfileResponse> future; @@ -133,12 +132,12 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); - cacheProvider = mock(CacheProvider.class); + cacheProvider = mock(EntityCacheProvider.class); EntityCache cache = mock(EntityCache.class); updates = 1L; when(cache.getTotalUpdates(anyString(), anyString())).thenReturn(updates); when(cache.isActive(anyString(), anyString())).thenReturn(isActive); - when(cache.getLastActiveMs(anyString(), anyString())).thenReturn(lastActiveTimestamp); + when(cache.getLastActiveTime(anyString(), anyString())).thenReturn(lastActiveTimestamp); Map<String, Long> modelSizeMap = new HashMap<>(); modelSizeMap.put(modelId, modelSize); when(cache.getModelSize(anyString())).thenReturn(modelSizeMap); @@ -265,7 +264,7 @@ private void registerHandler(FakeNode node) { } public void testInvalidRequest() { - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.empty()); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.empty()); action.doExecute(task, request, future); assertException(future, TimeSeriesException.class,
EntityProfileTransportAction.NO_NODE_FOUND_MSG); @@ -273,7 +272,7 @@ public void testInvalidRequest() { public void testLocalNodeHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); action.doExecute(task, request, future); @@ -283,7 +282,7 @@ public void testLocalNodeHit() { public void testAllHit() { DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion()); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(anyString())).thenReturn(Optional.of(localNode)); + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(anyString())).thenReturn(Optional.of(localNode)); when(clusterService.localNode()).thenReturn(localNode); request = new EntityProfileRequest(detectorId, entity, all); @@ -309,7 +308,7 @@ public void testGetRemoteUpdateResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); @@ -339,7 +338,7 @@ public void testGetRemoteFailureResponse() { cacheProvider ); - when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))) + when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))) .thenReturn(Optional.of(testNodes[1].discoveryNode())); registerHandler(testNodes[1]); diff --git a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java index 30177220b..f40b2543d 100644 --- a/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/EntityResultTransportActionTests.java @@ -49,30 +49,19 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.AnomalyDetectorJobRunnerTests; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; import org.opensearch.ad.common.exception.JsonPathNotFoundException; import org.opensearch.ad.constant.ADCommonMessages; import org.opensearch.ad.constant.ADCommonName; -import org.opensearch.ad.constant.CommonValue; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.CheckpointDao; -import org.opensearch.ad.ml.EntityColdStarter; -import org.opensearch.ad.ml.EntityModel; -import org.opensearch.ad.ml.ModelManager; -import org.opensearch.ad.ml.ModelState; +import org.opensearch.ad.ml.ADCheckpointDao; +import org.opensearch.ad.ml.ADEntityColdStart; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import 
org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Strings; import org.opensearch.common.settings.ClusterSettings; @@ -88,6 +77,7 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; import org.opensearch.transport.TransportService; import test.org.opensearch.ad.util.JsonDeserializer; @@ -101,15 +91,15 @@ public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { EntityResultTransportAction entityResult; ActionFilters actionFilters; TransportService transportService; - ModelManager manager; - ADCircuitBreakerService adCircuitBreakerService; - CheckpointDao checkpointDao; - CacheProvider provider; + ADModelManager manager; + CircuitBreakerService adCircuitBreakerService; + ADCheckpointDao checkpointDao; + EntityCacheProvider provider; EntityCache entityCache; - NodeStateManager stateManager; + ADNodeStateManager stateManager; Settings settings; Clock clock; - EntityResultRequest request; + EntityADResultRequest request; String detectorId; long timeoutMs; AnomalyDetector detector; @@ -124,13 +114,13 @@ public class EntityResultTransportActionTests extends AbstractTimeSeriesTest { double[] cacheHitData; String tooLongEntity; double[] tooLongData; - ResultWriteWorker resultWriteQueue; - CheckpointReadWorker checkpointReadQueue; + ADResultWriteWorker resultWriteQueue; + ADCheckpointReadWorker checkpointReadQueue; int minSamples; Instant now; - EntityColdStarter coldStarter; - ColdEntityWorker coldEntityQueue; - EntityColdStartWorker entityColdStartQueue; + ADEntityColdStart coldStarter; + ADColdEntityWorker coldEntityQueue; + ADColdStartWorker entityColdStartQueue; ADIndexManagement indexUtil; ClusterService clusterService; ADStats adStats; @@ -153,17 +143,17 @@ public void setUp() throws Exception { actionFilters = mock(ActionFilters.class); transportService = mock(TransportService.class); - adCircuitBreakerService = mock(ADCircuitBreakerService.class); + adCircuitBreakerService = mock(CircuitBreakerService.class); when(adCircuitBreakerService.isOpen()).thenReturn(false); - checkpointDao = mock(CheckpointDao.class); + checkpointDao = mock(ADCheckpointDao.class); detectorId = "123"; entities = new HashMap<>(); start = 10L; end = 20L; - request = new EntityResultRequest(detectorId, entities, start, end); + request = new EntityADResultRequest(detectorId, entities, start, end); clock = mock(Clock.class); now = Instant.now(); @@ -171,7 +161,7 @@ public void setUp() throws Exception { settings = Settings .builder() - .put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)) + .put(AnomalyDetectorSettings.AD_COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)) .put(AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ.getKey(), TimeValue.timeValueHours(12)) .build(); @@ -181,7 +171,7 @@ public void setUp() throws Exception { Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ))) ); 
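The ClusterSettings construction that closes just above is the same boilerplate repeated in nearly every setUp() this diff touches: a mocked ClusterService only accepts reads of settings registered with it, which is why each settings rename (FILTER_BY_BACKEND_ROLES -> AD_FILTER_BY_BACKEND_ROLES, COOLDOWN_MINUTES -> AD_COOLDOWN_MINUTES, and so on) forces a matching test edit. A hypothetical helper that would collapse those edits to one place (not part of this change; names are ours):

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

public final class ClusterServiceMocks {
    private ClusterServiceMocks() {}

    // Returns a mocked ClusterService whose ClusterSettings registers exactly the given settings;
    // reading any unregistered setting through it would throw, mirroring production behavior.
    public static ClusterService withRegisteredSettings(Setting<?>... registered) {
        ClusterService clusterService = mock(ClusterService.class);
        ClusterSettings clusterSettings = new ClusterSettings(
            Settings.EMPTY,
            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(registered)))
        );
        when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
        return clusterService;
    }
}

A setUp() would then shrink to clusterService = ClusterServiceMocks.withRegisteredSettings(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES); and a settings rename would touch only the call site.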
when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - manager = new ModelManager( + manager = new ADModelManager( null, clock, 0, @@ -192,25 +182,25 @@ public void setUp() throws Exception { 0, null, AnomalyDetectorSettings.CHECKPOINT_SAVING_FREQ, - mock(EntityColdStarter.class), + mock(ADEntityColdStart.class), null, null, settings, clusterService ); - provider = mock(CacheProvider.class); + provider = mock(EntityCacheProvider.class); entityCache = mock(EntityCache.class); when(provider.get()).thenReturn(entityCache); String field = "a"; detector = TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Arrays.asList(field)); - stateManager = mock(NodeStateManager.class); + stateManager = mock(ADNodeStateManager.class); doAnswer(invocation -> { ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1); listener.onResponse(Optional.of(detector)); return null; - }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(stateManager).getConfig(any(String.class), any(ActionListener.class)); cacheMissEntity = "0.0.0.1"; cacheMissData = new double[] { 0.1 }; @@ -224,7 +214,8 @@ public void setUp() throws Exception { tooLongData = new double[] { 0.3 }; entities.put(Entity.createSingleAttributeEntity(detector.getCategoryFields().get(0), tooLongEntity), tooLongData); - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); + ADModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(true).build()); when(entityCache.get(eq(cacheMissEntityObj.getModelId(detectorId).get()), any())).thenReturn(null); when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); @@ -235,31 +226,31 @@ public void setUp() throws Exception { indexUtil = mock(ADIndexManagement.class); when(indexUtil.getSchemaVersion(any())).thenReturn(CommonValue.NO_SCHEMA_VERSION); - resultWriteQueue = mock(ResultWriteWorker.class); - checkpointReadQueue = mock(CheckpointReadWorker.class); + resultWriteQueue = mock(ADResultWriteWorker.class); + checkpointReadQueue = mock(ADCheckpointReadWorker.class); minSamples = 1; - coldStarter = mock(EntityColdStarter.class); + coldStarter = mock(ADEntityColdStart.class); doAnswer(invocation -> { - ModelState modelState = invocation.getArgument(0); + ADModelState modelState = invocation.getArgument(0); modelState.getModel().clear(); return null; }).when(coldStarter).trainModelFromExistingSamples(any(), anyInt()); - coldEntityQueue = mock(ColdEntityWorker.class); - entityColdStartQueue = mock(EntityColdStartWorker.class); + coldEntityQueue = mock(ADColdEntityWorker.class); + entityColdStartQueue = mock(ADColdStartWorker.class); - Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() { + Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() { { - put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier())); + put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier())); } }; adStats = new ADStats(statsMap); - entityResult = new EntityResultTransportAction( + entityResult = new EntityADResultTransportAction( actionFilters, transportService, manager, @@ -305,7 +296,7 @@ public void testFailtoGetDetector() { ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1); listener.onResponse(Optional.empty()); return null; - }).when(stateManager).getAnomalyDetector(any(String.class), any(ActionListener.class)); + }).when(stateManager).getConfig(any(String.class), any(ActionListener.class));
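Below, testEmptyId, testReverseTime, and testNegativeTime exercise the renamed EntityADResultRequest's validate(). For orientation, a sketch of the happy path those checks bracket, in the same test idiom; the constructor argument order is taken from this diff, while the package of EntityADResultRequest and the category value are our assumptions:

import java.util.HashMap;
import java.util.Map;

import org.opensearch.ad.transport.EntityADResultRequest; // assumed package, mirroring the old EntityResultRequest
import org.opensearch.timeseries.model.Entity;

public void testValidRequestSketch() {
    // One entity keyed by its category-field value, mapped to its feature vector.
    Map<Entity, double[]> entities = new HashMap<>();
    entities.put(Entity.createSingleAttributeEntity("service", "app_0"), new double[] { 0.1 });

    // Constructor order per this diff: configId, entities map, data start/end in epoch millis.
    EntityADResultRequest request = new EntityADResultRequest("detector-123", entities, 10L, 20L);

    // null means well-formed; an empty id or end <= start yields the
    // validation errors asserted in the tests that follow.
    assertNull(request.validate());
}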
PlainActionFuture future = PlainActionFuture.newFuture(); @@ -316,7 +307,8 @@ public void testFailtoGetDetector() { // test rcf score is 0 public void testNoResultsToSave() { - ModelState state = MLUtil.randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build()); + ADModelState state = MLUtil + .randomModelState(new RandomModelStateConfig.Builder().fullModel(false).build()); when(entityCache.get(eq(cacheHitEntityObj.getModelId(detectorId).get()), any())).thenReturn(state); PlainActionFuture future = PlainActionFuture.newFuture(); @@ -334,19 +326,19 @@ public void testValidRequest() { } public void testEmptyId() { - request = new EntityResultRequest("", entities, start, end); + request = new EntityADResultRequest("", entities, start, end); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(ADCommonMessages.AD_ID_MISSING_MSG)); } public void testReverseTime() { - request = new EntityResultRequest(detectorId, entities, end, start); + request = new EntityADResultRequest(detectorId, entities, end, start); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); } public void testNegativeTime() { - request = new EntityResultRequest(detectorId, entities, start, -end); + request = new EntityADResultRequest(detectorId, entities, start, -end); ActionRequestValidationException e = request.validate(); assertThat(e.validationErrors(), hasItem(startsWith(CommonMessages.INVALID_TIMESTAMP_ERR_MSG))); } @@ -383,9 +375,9 @@ public void testJsonResponse() throws IOException, JsonPathNotFoundException { } public void testFailToScore() { - ModelManager spyModelManager = spy(manager); - doThrow(new IllegalArgumentException()).when(spyModelManager).getAnomalyResultForEntity(any(), any(), anyString(), any(), anyInt()); - entityResult = new EntityResultTransportAction( + ADModelManager spyModelManager = spy(manager); + doThrow(new IllegalArgumentException()).when(spyModelManager).getResult(any(), any(), anyString(), any(), anyInt()); + entityResult = new EntityADResultTransportAction( actionFilters, transportService, spyModelManager, @@ -408,9 +400,9 @@ public void testFailToScore() { future.actionGet(timeoutMs); verify(resultWriteQueue, never()).put(any()); - verify(entityCache, times(1)).removeEntityModel(anyString(), anyString()); + verify(entityCache, times(1)).removeModel(anyString(), anyString()); verify(entityColdStartQueue, times(1)).put(any()); - Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue(); + Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue(); assertEquals(1L, ((Long) val).longValue()); } } diff --git a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java index f2da82c36..254dfb93f 100644 --- a/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/ForwardADTaskTransportActionTests.java @@ -31,15 +31,16 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; +import org.opensearch.ad.ADNodeStateManager; import org.opensearch.ad.ADUnitTestCase; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; import 
org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; import org.opensearch.tasks.Task; import org.opensearch.timeseries.TestHelpers; +import org.opensearch.timeseries.feature.FeatureManager; +import org.opensearch.timeseries.transport.JobResponse; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableList; @@ -50,10 +51,10 @@ public class ForwardADTaskTransportActionTests extends ADUnitTestCase { private ADTaskManager adTaskManager; private ADTaskCacheManager adTaskCacheManager; private FeatureManager featureManager; - private NodeStateManager stateManager; + private ADNodeStateManager stateManager; private ForwardADTaskTransportAction forwardADTaskTransportAction; private Task task; - private ActionListener listener; + private ActionListener listener; @SuppressWarnings("unchecked") @Override @@ -64,7 +65,7 @@ public void setUp() throws Exception { adTaskManager = mock(ADTaskManager.class); adTaskCacheManager = mock(ADTaskCacheManager.class); featureManager = mock(FeatureManager.class); - stateManager = mock(NodeStateManager.class); + stateManager = mock(ADNodeStateManager.class); forwardADTaskTransportAction = new ForwardADTaskTransportAction( actionFilters, transportService, @@ -88,7 +89,7 @@ public void testCheckAvailableTaskSlots() throws IOException { public void testNextEntityTaskForSingleEntityDetector() throws IOException { when(adTaskCacheManager.hasEntity(anyString())).thenReturn(false); - ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR); ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, NEXT_ENTITY); forwardADTaskTransportAction.doExecute(task, request, listener); verify(listener, times(1)).onFailure(any()); @@ -115,7 +116,7 @@ public void testNextEntityTaskWithPendingEntity() throws IOException { } public void testPushBackEntityForSingleEntityDetector() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR); ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, PUSH_BACK_ENTITY); forwardADTaskTransportAction.doExecute(task, request, listener); verify(listener, times(1)).onFailure(any()); @@ -219,7 +220,7 @@ public void testScaleEntityTaskSlotsWithAvailableSlots() throws IOException { } public void testCancelSingleEntityDetector() throws IOException { - ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_ENTITY); + ADTask adTask = TestHelpers.randomAdTask(ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR); ForwardADTaskRequest request = new ForwardADTaskRequest(adTask, CANCEL); forwardADTaskTransportAction.doExecute(task, request, listener); verify(listener, times(1)).onFailure(any()); diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java index 60144c63c..fb1b930ea 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorActionTests.java @@ -20,11 +20,12 @@ import org.mockito.Mockito; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.DetectorProfile; import 
org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.rest.RestStatus; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -40,11 +41,11 @@ public void setUp() throws Exception { @Test public void testGetRequest() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, false, false, "nonempty", "", false, null); + GetConfigRequest request = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + GetConfigRequest newRequest = new GetConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); } @@ -52,7 +53,7 @@ public void testGetRequest() throws IOException { public void testGetResponse() throws Exception { BytesStreamOutput out = new BytesStreamOutput(); AnomalyDetector detector = Mockito.mock(AnomalyDetector.class); - AnomalyDetectorJob detectorJob = Mockito.mock(AnomalyDetectorJob.class); + Job detectorJob = Mockito.mock(Job.class); Mockito.doNothing().when(detector).writeTo(out); GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( 1234, diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java index 4a3f2a89c..9c7d132bb 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTests.java @@ -40,14 +40,15 @@ import org.opensearch.action.get.MultiGetResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.constant.ADCommonMessages; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; import org.opensearch.ad.model.ADTaskType; +import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.SecurityClientUtil; -import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -58,18 +59,20 @@ import org.opensearch.timeseries.constant.CommonMessages; import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportService; public class GetAnomalyDetectorTests extends AbstractTimeSeriesTest { - private GetAnomalyDetectorTransportAction action; + private BaseGetConfigTransportAction action; private 
TransportService transportService; private DiscoveryNodeFilterer nodeFilter; private ActionFilters actionFilters; private Client client; private SecurityClientUtil clientUtil; - private GetAnomalyDetectorRequest request; + private GetConfigRequest request; private String detectorId = "yecrdnUBqurvo9uKU_d8"; private String entityValue = "app_0"; private String categoryField = "categoryField"; @@ -95,7 +98,7 @@ public void setUp() throws Exception { ClusterService clusterService = mock(ClusterService.class); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -119,7 +122,7 @@ public void setUp() throws Exception { Clock clock = mock(Clock.class); Throttler throttler = new Throttler(clock); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); adTaskManager = mock(ADTaskManager.class); @@ -144,7 +147,7 @@ public void testInvalidRequest() throws IOException { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); @@ -169,7 +172,7 @@ public void testValidRequest() throws IOException { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_/_profile"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, false, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.doExecute(null, request, future); @@ -185,17 +188,7 @@ public void testGetTransportActionWithReturnTask() { return null; }) .when(adTaskManager) - .getAndExecuteOnLatestADTasks( - anyString(), - eq(null), - eq(null), - anyList(), - any(), - eq(transportService), - eq(true), - anyInt(), - any() - ); + .getAndExecuteOnLatestTasks(anyString(), eq(null), eq(null), anyList(), any(), eq(transportService), eq(true), anyInt(), any()); doAnswer(invocation -> { Object[] args = invocation.getArguments(); @@ -207,7 +200,7 @@ public void testGetTransportActionWithReturnTask() { rawPath = "_opendistro/_anomaly_detection/detectors/T4c3dXUBj-2IZN7itix_"; - request = new GetAnomalyDetectorRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity); + request = new GetConfigRequest(detectorId, 0L, false, true, typeStr, rawPath, false, entity); future = new PlainActionFuture<>(); action.getExecute(request, future); @@ -233,11 +226,11 @@ private MultiGetResponse createMultiGetResponse() { } private List createADTaskList() { - ADTask adTask1 = new ADTask.Builder().taskId("test1").taskType(ADTaskType.REALTIME_SINGLE_ENTITY.name()).build(); - ADTask adTask2 = new ADTask.Builder().taskId("test2").taskType(ADTaskType.REALTIME_SINGLE_ENTITY.name()).build(); + ADTask adTask1 = new ADTask.Builder().taskId("test1").taskType(ADTaskType.REALTIME_SINGLE_STREAM_DETECTOR.name()).build(); + ADTask adTask2 = new 
ADTask.Builder().taskId("test2").taskType(ADTaskType.REALTIME_SINGLE_STREAM_DETECTOR.name()).build(); ADTask adTask3 = new ADTask.Builder().taskId("test3").taskType(ADTaskType.REALTIME_HC_DETECTOR.name()).build(); ADTask adTask4 = new ADTask.Builder().taskId("test4").taskType(ADTaskType.HISTORICAL_HC_DETECTOR.name()).build(); - ADTask adTask5 = new ADTask.Builder().taskId("test5").taskType(ADTaskType.HISTORICAL_SINGLE_ENTITY.name()).build(); + ADTask adTask5 = new ADTask.Builder().taskId("test5").taskType(ADTaskType.HISTORICAL_SINGLE_STREAM_DETECTOR.name()).build(); return Arrays.asList(adTask1, adTask2, adTask3, adTask4, adTask5); } diff --git a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java index 34f1485c2..4d408fd65 100644 --- a/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/GetAnomalyDetectorTransportActionTests.java @@ -26,16 +26,17 @@ import org.mockito.Mockito; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; -import org.opensearch.ad.NodeStateManager; import org.opensearch.ad.constant.ADCommonName; +import org.opensearch.ad.indices.ADIndex; +import org.opensearch.ad.indices.ADIndexManagement; import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.model.AnomalyDetectorJob; import org.opensearch.ad.model.EntityProfile; import org.opensearch.ad.model.InitProgressProfile; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.*; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.settings.ClusterSettings; @@ -50,17 +51,22 @@ import org.opensearch.test.OpenSearchSingleNodeTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.timeseries.NodeStateManager; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.model.Entity; +import org.opensearch.timeseries.model.Job; +import org.opensearch.timeseries.transport.BaseGetConfigTransportAction; +import org.opensearch.timeseries.transport.GetConfigRequest; import org.opensearch.timeseries.util.DiscoveryNodeFilterer; import org.opensearch.timeseries.util.RestHandlerUtils; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; public class GetAnomalyDetectorTransportActionTests extends OpenSearchSingleNodeTestCase { private static ThreadPool threadPool; - private GetAnomalyDetectorTransportAction action; + private BaseGetConfigTransportAction action; private Task task; private ActionListener response; private ADTaskManager adTaskManager; @@ -85,7 +91,7 @@ public void setUp() throws Exception { ClusterService clusterService = mock(ClusterService.class); ClusterSettings clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) ); 
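For reference, the ADTaskType renames applied in createADTaskList above and in the ForwardADTaskTransportActionTests hunks earlier appear to be name-only changes; the HC task types keep their names:

// ADTaskType constants, old -> new, as observed in this diff:
// REALTIME_SINGLE_ENTITY   -> REALTIME_SINGLE_STREAM_DETECTOR
// HISTORICAL_SINGLE_ENTITY -> HISTORICAL_SINGLE_STREAM_DETECTOR
// REALTIME_HC_DETECTOR     -> (unchanged)
// HISTORICAL_HC_DETECTOR   -> (unchanged)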
when(clusterService.getClusterSettings()).thenReturn(clusterSettings); adTaskManager = mock(ADTaskManager.class); @@ -125,31 +131,13 @@ protected NamedWriteableRegistry writableRegistry() { @Test public void testGetTransportAction() throws IOException { - GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( - "1234", - 4321, - false, - false, - "nonempty", - "", - false, - null - ); + GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, false, false, "nonempty", "", false, null); action.doExecute(task, getAnomalyDetectorRequest, response); } @Test public void testGetTransportActionWithReturnJob() throws IOException { - GetAnomalyDetectorRequest getAnomalyDetectorRequest = new GetAnomalyDetectorRequest( - "1234", - 4321, - true, - false, - "", - "abcd", - false, - null - ); + GetConfigRequest getAnomalyDetectorRequest = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null); action.doExecute(task, getAnomalyDetectorRequest, response); } @@ -161,23 +149,23 @@ public void testGetAction() { @Test public void testGetAnomalyDetectorRequest() throws IOException { - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, entity); + GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, entity); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); - Assert.assertEquals(request.getDetectorID(), newRequest.getDetectorID()); + GetConfigRequest newRequest = new GetConfigRequest(input); + Assert.assertEquals(request.getConfigID(), newRequest.getConfigID()); Assert.assertEquals(request.getRawPath(), newRequest.getRawPath()); Assert.assertNull(newRequest.validate()); } @Test public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException { - GetAnomalyDetectorRequest request = new GetAnomalyDetectorRequest("1234", 4321, true, false, "", "abcd", false, null); + GetConfigRequest request = new GetConfigRequest("1234", 4321, true, false, "", "abcd", false, null); BytesStreamOutput out = new BytesStreamOutput(); request.writeTo(out); StreamInput input = out.bytes().streamInput(); - GetAnomalyDetectorRequest newRequest = new GetAnomalyDetectorRequest(input); + GetConfigRequest newRequest = new GetConfigRequest(input); Assert.assertNull(newRequest.getEntity()); } @@ -186,7 +174,7 @@ public void testGetAnomalyDetectorRequestNoEntityValue() throws IOException { public void testGetAnomalyDetectorResponse() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); - AnomalyDetectorJob adJob = TestHelpers.randomAnomalyDetectorJob(); + Job adJob = TestHelpers.randomAnomalyDetectorJob(); GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( 4321, "1234", @@ -220,7 +208,7 @@ public void testGetAnomalyDetectorResponse() throws IOException { public void testGetAnomalyDetectorProfileResponse() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); AnomalyDetector detector = TestHelpers.randomAnomalyDetector(ImmutableMap.of("testKey", "testValue"), Instant.now()); - AnomalyDetectorJob adJob = TestHelpers.randomAnomalyDetectorJob(); + Job adJob = TestHelpers.randomAnomalyDetectorJob(); InitProgressProfile initProgress = new 
InitProgressProfile("99%", 2L, 2); EntityProfile entityProfile = new EntityProfile.Builder().initProgress(initProgress).build(); GetAnomalyDetectorResponse response = new GetAnomalyDetectorResponse( diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java index f29030912..2b3a5aa2e 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorActionTests.java @@ -34,6 +34,7 @@ import com.google.common.collect.ImmutableMap; public class IndexAnomalyDetectorActionTests extends OpenSearchSingleNodeTestCase { + @Override @Before public void setUp() throws Exception { super.setUp(); diff --git a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java index 0a3859bc2..efff54df8 100644 --- a/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java +++ b/src/test/java/org/opensearch/ad/transport/IndexAnomalyDetectorTransportActionTests.java @@ -19,10 +19,8 @@ import java.time.Instant; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.Locale; -import java.util.Map; import org.junit.Assert; import org.junit.Before; @@ -37,13 +35,15 @@ import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.WriteRequest; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.feature.SearchFeatureDao; +import org.opensearch.ad.ADNodeStateManager; +import org.opensearch.ad.indices.ADIndex; import org.opensearch.ad.indices.ADIndexManagement; +import org.opensearch.ad.model.ADTask; +import org.opensearch.ad.model.ADTaskType; import org.opensearch.ad.model.AnomalyDetector; import org.opensearch.ad.settings.AnomalyDetectorSettings; +import org.opensearch.ad.task.ADTaskCacheManager; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.SecurityClientUtil; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; @@ -63,6 +63,9 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.timeseries.TestHelpers; import org.opensearch.timeseries.constant.CommonName; +import org.opensearch.timeseries.feature.SearchFeatureDao; +import org.opensearch.timeseries.task.TaskManager; +import org.opensearch.timeseries.util.SecurityClientUtil; import org.opensearch.transport.TransportService; import com.google.common.collect.ImmutableMap; @@ -74,7 +77,7 @@ public class IndexAnomalyDetectorTransportActionTests extends OpenSearchIntegTes private ActionListener response; private ClusterService clusterService; private ClusterSettings clusterSettings; - private ADTaskManager adTaskManager; + private TaskManager adTaskManager; private Client client = mock(Client.class); private SecurityClientUtil clientUtil; private SearchFeatureDao searchFeatureDao; @@ -87,7 +90,7 @@ public void setUp() throws Exception { clusterService = mock(ClusterService.class); clusterSettings = new ClusterSettings( Settings.EMPTY, - Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES))) + Collections.unmodifiableSet(new 
HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES))) ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); @@ -100,14 +103,16 @@ public void setUp() throws Exception { .build(); final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); IndexMetadata indexMetaData = IndexMetadata.builder(CommonName.CONFIG_INDEX).settings(existingSettings).build(); - final Map<String, IndexMetadata> indices = new HashMap<>(); - indices.put(CommonName.CONFIG_INDEX, indexMetaData); + final ImmutableOpenMap<String, IndexMetadata> indices = ImmutableOpenMap + .builder() + .fPut(CommonName.CONFIG_INDEX, indexMetaData) + .build(); ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build(); when(clusterService.state()).thenReturn(clusterState); adTaskManager = mock(ADTaskManager.class); searchFeatureDao = mock(SearchFeatureDao.class); - NodeStateManager nodeStateManager = mock(NodeStateManager.class); + ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class); clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY); action = new IndexAnomalyDetectorTransportAction( mock(TransportService.class), @@ -199,7 +204,7 @@ public void testIndexTransportAction() { @Test public void testIndexTransportActionWithUserAndFilterOn() { - Settings settings = Settings.builder().put(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build(); + Settings settings = Settings.builder().put(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.getKey(), true).build(); ThreadContext threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations"); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); diff --git a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java index 9ec5aa9d5..7d03a90a0 100644 --- a/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java +++ b/src/test/java/org/opensearch/ad/transport/MultiEntityResultTests.java @@ -24,10 +24,10 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.BACKOFF_MINUTES; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_ENTITIES_PER_QUERY; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.MAX_RETRY_FOR_UNRESPONSIVE_NODE; -import static org.opensearch.ad.settings.AnomalyDetectorSettings.PAGE_SIZE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_BACKOFF_MINUTES; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE; +import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_PAGE_SIZE; import java.io.IOException; import java.time.Clock; @@ -69,30 +69,16 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.ad.NodeStateManager; -import org.opensearch.ad.breaker.ADCircuitBreakerService; -import org.opensearch.ad.caching.CacheProvider; -import org.opensearch.ad.caching.EntityCache; -import
org.opensearch.ad.cluster.HashRing; -import org.opensearch.ad.feature.CompositeRetriever; -import org.opensearch.ad.feature.FeatureManager; import org.opensearch.ad.indices.ADIndexManagement; -import org.opensearch.ad.ml.ModelManager; +import org.opensearch.ad.ml.ADModelManager; import org.opensearch.ad.ml.ThresholdingResult; import org.opensearch.ad.model.AnomalyDetector; -import org.opensearch.ad.ratelimit.CheckpointReadWorker; -import org.opensearch.ad.ratelimit.ColdEntityWorker; -import org.opensearch.ad.ratelimit.EntityColdStartWorker; -import org.opensearch.ad.ratelimit.EntityFeatureRequest; -import org.opensearch.ad.ratelimit.ResultWriteWorker; +import org.opensearch.ad.ratelimit.ADCheckpointReadWorker; +import org.opensearch.ad.ratelimit.ADColdEntityWorker; +import org.opensearch.ad.ratelimit.ADColdStartWorker; +import org.opensearch.ad.ratelimit.ADResultWriteWorker; import org.opensearch.ad.settings.AnomalyDetectorSettings; -import org.opensearch.ad.stats.ADStat; -import org.opensearch.ad.stats.ADStats; -import org.opensearch.ad.stats.suppliers.CounterSupplier; import org.opensearch.ad.task.ADTaskManager; -import org.opensearch.ad.util.ClientUtil; -import org.opensearch.ad.util.SecurityClientUtil; -import org.opensearch.ad.util.Throttler; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -123,7 +109,9 @@ import org.opensearch.timeseries.constant.CommonName; import org.opensearch.timeseries.model.Entity; import org.opensearch.timeseries.model.IntervalTimeConfiguration; +import org.opensearch.timeseries.settings.TimeSeriesSettings; import org.opensearch.timeseries.stats.StatNames; +import org.opensearch.timeseries.stats.suppliers.CounterSupplier; import org.opensearch.transport.Transport; import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportInterceptor; @@ -144,27 +132,27 @@ public class MultiEntityResultTests extends AbstractTimeSeriesTest { private TransportInterceptor entityResultInterceptor; private Clock clock; private AnomalyDetector detector; - private NodeStateManager stateManager; + private ADNodeStateManager stateManager; private static Settings settings; private TransportService transportService; private Client client; private SecurityClientUtil clientUtil; private FeatureManager featureQuery; - private ModelManager normalModelManager; + private ADModelManager normalModelManager; private HashRing hashRing; private ClusterService clusterService; private IndexNameExpressionResolver indexNameResolver; - private ADCircuitBreakerService adCircuitBreakerService; + private CircuitBreakerService adCircuitBreakerService; private ADStats adStats; private ThreadPool mockThreadPool; private String detectorId; private Instant now; - private CacheProvider provider; + private EntityCacheProvider provider; private ADIndexManagement indexUtil; - private ResultWriteWorker resultWriteQueue; - private CheckpointReadWorker checkpointReadQueue; - private EntityColdStartWorker entityColdStartQueue; - private ColdEntityWorker coldEntityQueue; + private ADResultWriteWorker resultWriteQueue; + private ADCheckpointReadWorker checkpointReadQueue; + private ADColdStartWorker entityColdStartQueue; + private ADColdEntityWorker coldEntityQueue; private String app0 = "app_0"; private String server1 = "server_1"; private String server2 = "server_2"; @@ -198,15 +186,15 @@ public void setUp() throws Exception { String categoryField = "a"; detector = 
             TestHelpers.randomAnomalyDetectorUsingCategoryFields(detectorId, Collections.singletonList(categoryField));
-        stateManager = mock(NodeStateManager.class);
+        stateManager = mock(ADNodeStateManager.class);
 
         // make sure parameters are not null, otherwise this mock won't get invoked
         doAnswer(invocation -> {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(stateManager).getAnomalyDetector(anyString(), any(ActionListener.class));
+        }).when(stateManager).getConfig(anyString(), any(ActionListener.class));
 
-        settings = Settings.builder().put(AnomalyDetectorSettings.COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build();
+        settings = Settings.builder().put(AnomalyDetectorSettings.AD_COOLDOWN_MINUTES.getKey(), TimeValue.timeValueMinutes(5)).build();
 
         // make sure end time is large enough compared to Clock.systemUTC().millis() to get PageIterator.hasNext() to pass
         request = new AnomalyResultRequest(detectorId, 100, Clock.systemUTC().millis() + 100_000);
@@ -223,15 +211,15 @@ public void setUp() throws Exception {
 
         featureQuery = mock(FeatureManager.class);
 
-        normalModelManager = mock(ModelManager.class);
+        normalModelManager = mock(ADModelManager.class);
 
         hashRing = mock(HashRing.class);
 
         Set<Setting<?>> anomalyResultSetting = new HashSet<>(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
-        anomalyResultSetting.add(MAX_ENTITIES_PER_QUERY);
-        anomalyResultSetting.add(PAGE_SIZE);
-        anomalyResultSetting.add(MAX_RETRY_FOR_UNRESPONSIVE_NODE);
-        anomalyResultSetting.add(BACKOFF_MINUTES);
+        anomalyResultSetting.add(AD_MAX_ENTITIES_PER_QUERY);
+        anomalyResultSetting.add(AD_PAGE_SIZE);
+        anomalyResultSetting.add(AD_MAX_RETRY_FOR_UNRESPONSIVE_NODE);
+        anomalyResultSetting.add(AD_BACKOFF_MINUTES);
         ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, anomalyResultSetting);
 
         DiscoveryNode discoveryNode = new DiscoveryNode(
@@ -246,16 +234,16 @@ public void setUp() throws Exception {
 
         indexNameResolver = new IndexNameExpressionResolver(new ThreadContext(Settings.EMPTY));
 
-        adCircuitBreakerService = mock(ADCircuitBreakerService.class);
+        adCircuitBreakerService = mock(CircuitBreakerService.class);
         when(adCircuitBreakerService.isOpen()).thenReturn(false);
 
-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };
         adStats = new ADStats(statsMap);
@@ -267,7 +255,7 @@ public void setUp() throws Exception {
             return null;
         })
             .when(adTaskManager)
-            .initRealtimeTaskCacheAndCleanupStaleCache(
+            .initCacheWithCleanupIfRequired(
                 anyString(),
                 any(AnomalyDetector.class),
                 any(TransportService.class),
@@ -293,7 +281,7 @@ public void setUp() throws Exception {
             adTaskManager
         );
 
-        provider = mock(CacheProvider.class);
+        provider = mock(EntityCacheProvider.class);
         entityCache = mock(EntityCache.class);
         when(provider.get()).thenReturn(entityCache);
         when(entityCache.get(any(), any()))
@@ -301,11 +289,11 @@ public void setUp() throws Exception {
         when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(new ArrayList<Entity>(), new ArrayList<Entity>()));
 
         indexUtil = mock(ADIndexManagement.class);
-        resultWriteQueue = mock(ResultWriteWorker.class);
-        checkpointReadQueue = mock(CheckpointReadWorker.class);
-        entityColdStartQueue = mock(EntityColdStartWorker.class);
+        resultWriteQueue = mock(ADResultWriteWorker.class);
+        checkpointReadQueue = mock(ADCheckpointReadWorker.class);
+        entityColdStartQueue = mock(ADColdStartWorker.class);
 
-        coldEntityQueue = mock(ColdEntityWorker.class);
+        coldEntityQueue = mock(ADColdEntityWorker.class);
 
         attrs1 = new HashMap<>();
         attrs1.put(serviceField, app0);
@@ -396,9 +384,9 @@ public String executor() {
         };
     }
 
-    private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager) {
+    private void setUpEntityResult(int nodeIndex, ADNodeStateManager nodeStateManager) {
         // register entity result action
-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[nodeIndex].transportService,
@@ -415,8 +403,7 @@ private void setUpEntityResult(int nodeIndex, NodeStateManager nodeStateManager)
             adStats
         );
 
-        when(normalModelManager.getAnomalyResultForEntity(any(), any(), any(), any(), anyInt()))
-            .thenReturn(new ThresholdingResult(0, 1, 1));
+        when(normalModelManager.getResult(any(), any(), any(), any(), anyInt())).thenReturn(new ThresholdingResult(0, 1, 1));
     }
 
     private void setUpEntityResult(int nodeIndex) {
@@ -431,18 +418,18 @@ public void setUpNormlaStateManager() throws IOException {
             .setCategoryFields(ImmutableList.of(randomAlphaOfLength(5)))
             .build();
         doAnswer(invocation -> {
-            ActionListener<GetResponse> listener = invocation.getArgument(1);
+            ActionListener<GetResponse> listener = invocation.getArgument(2);
             listener.onResponse(TestHelpers.createGetResponse(detector, detectorId, CommonName.CONFIG_INDEX));
             return null;
         }).when(client).get(any(GetRequest.class), any(ActionListener.class));
 
-        stateManager = new NodeStateManager(
+        stateManager = new ADNodeStateManager(
             client,
             xContentRegistry(),
             settings,
             new ClientUtil(settings, client, new Throttler(mock(Clock.class)), threadPool),
             clock,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             clusterService
         );
@@ -537,11 +524,7 @@ public void testIndexNotFound() throws InterruptedException, IOException {
         PlainActionFuture<AnomalyResultResponse> listener2 = new PlainActionFuture<>();
         action.doExecute(null, request, listener2);
         Exception e = expectThrows(EndRunException.class, () -> listener2.actionGet(10000L));
-        assertThat(
-            "actual message: " + e.getMessage(),
-            e.getMessage(),
-            containsString(AnomalyResultTransportAction.TROUBLE_QUERYING_ERR_MSG)
-        );
+        assertThat("actual message: " + e.getMessage(), e.getMessage(), containsString(ResultProcessor.TROUBLE_QUERYING_ERR_MSG));
         assertTrue(!((EndRunException) e).isEndNow());
     }
@@ -644,7 +627,7 @@ private void setUpSearchResponse() throws IOException {
     private <T extends TransportResponse> void setUpTransportInterceptor(
         Function<TransportResponseHandler<T>, TransportResponseHandler<T>> interceptor,
-        NodeStateManager nodeStateManager
+        ADNodeStateManager nodeStateManager
     ) {
         entityResultInterceptor = new TransportInterceptor() {
             @Override
@@ -659,7 +642,7 @@ public void sendRequest(
                 TransportRequestOptions options,
                 TransportResponseHandler<T> handler
             ) {
-                if (action.equals(EntityResultAction.NAME)) {
+                if (action.equals(EntityADResultAction.NAME)) {
                     sender
                         .sendRequest(
                             connection,
@@ -678,7 +661,7 @@ public void sendRequest(
         // we started supporting multi-category fields in 1.1
         // Setting the version to 1.1 forces the outbound/inbound message to use the 1.1 wire format
-        setupTestNodes(entityResultInterceptor, 5, settings, Version.V_2_0_0, MAX_ENTITIES_PER_QUERY, PAGE_SIZE);
+        setupTestNodes(entityResultInterceptor, 5, settings, Version.V_2_0_0, AD_MAX_ENTITIES_PER_QUERY, AD_PAGE_SIZE);
 
         TransportService realTransportService = testNodes[0].transportService;
         ClusterService realClusterService = testNodes[0].clusterService;
@@ -713,7 +696,7 @@ public void testNonEmptyFeatures() throws InterruptedException, IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
         setUpEntityResult(1);
@@ -747,29 +730,29 @@ public void testCircuitBreakerOpen() throws InterruptedException, IOException {
             return null;
         }).when(clientUtil).asyncRequest(any(GetRequest.class), any(), any(ActionListener.class));
 
-        stateManager = new NodeStateManager(
+        stateManager = new ADNodeStateManager(
             client,
             xContentRegistry(),
             settings,
             clientUtil,
             clock,
-            AnomalyDetectorSettings.HOURLY_MAINTENANCE,
+            TimeSeriesSettings.HOURLY_MAINTENANCE,
             clusterService
         );
 
-        NodeStateManager spyStateManager = spy(stateManager);
+        ADNodeStateManager spyStateManager = spy(stateManager);
 
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler, spyStateManager);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
 
-        ADCircuitBreakerService openBreaker = mock(ADCircuitBreakerService.class);
+        CircuitBreakerService openBreaker = mock(CircuitBreakerService.class);
         when(openBreaker.isOpen()).thenReturn(true);
 
         // register entity result action
-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[1].transportService,
@@ -812,7 +795,7 @@ public void testNotAck() throws InterruptedException, IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::unackEntityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
         setUpEntityResult(1);
@@ -843,13 +826,13 @@ public void testMultipleNode() throws InterruptedException, IOException {
         Entity entity3 = Entity.createEntityByReordering(attrs3);
 
         // we use ordered attribute values as the key to the hash ring
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity1.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity1.toString())))
             .thenReturn(Optional.of(testNodes[2].discoveryNode()));
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity2.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity2.toString())))
            .thenReturn(Optional.of(testNodes[3].discoveryNode()));
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(eq(entity3.toString())))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(eq(entity3.toString())))
             .thenReturn(Optional.of(testNodes[4].discoveryNode()));
 
         for (int i = 2; i <= 4; i++) {
@@ -879,7 +862,7 @@ public void testCacheSelectionError() throws IOException, InterruptedException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         setUpEntityResult(1);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
 
         List<Entity> hotEntities = new ArrayList<>();
@@ -936,7 +919,7 @@ public boolean matches(List argument) {
     public void testCacheSelection() throws IOException, InterruptedException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
 
         List<Entity> hotEntities = new ArrayList<>();
@@ -947,13 +930,13 @@ public void testCacheSelection() throws IOException, InterruptedException {
         Entity entity2 = Entity.createEntityByReordering(attrs2);
         coldEntities.add(entity2);
 
-        provider = mock(CacheProvider.class);
+        provider = mock(EntityCacheProvider.class);
         entityCache = mock(EntityCache.class);
         when(provider.get()).thenReturn(entityCache);
         when(entityCache.selectUpdateCandidate(any(), any(), any())).thenReturn(Pair.of(hotEntities, coldEntities));
         when(entityCache.get(any(), any())).thenReturn(null);
 
-        new EntityResultTransportAction(
+        new EntityADResultTransportAction(
             new ActionFilters(Collections.emptySet()),
             // since we send requests to testNodes[1]
             testNodes[1].transportService,
@@ -1125,7 +1108,7 @@ public void testRetry() throws IOException, InterruptedException {
         }).when(coldEntityQueue).putAll(any());
 
         setUpTransportInterceptor(this::entityResultHandler);
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
         setUpEntityResult(1);
@@ -1197,14 +1180,14 @@ public void testEmptyPageToString() {
     }
 
     @SuppressWarnings("unchecked")
-    private NodeStateManager setUpTestExceptionTestingInModelNode() throws IOException {
+    private ADNodeStateManager setUpTestExceptionTestingInModelNode() throws IOException {
         setUpSearchResponse();
         setUpTransportInterceptor(this::entityResultHandler);
         // mock hashing ring response. This has to happen after setting up test nodes with the failure interceptor
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
 
-        NodeStateManager modelNodeStateManager = mock(NodeStateManager.class);
+        ADNodeStateManager modelNodeStateManager = mock(ADNodeStateManager.class);
         CountDownLatch modelNodeInProgress = new CountDownLatch(1);
         // make sure parameters are not null, otherwise this mock won't get invoked
         doAnswer(invocation -> {
@@ -1212,12 +1195,12 @@ private NodeStateManager setUpTestExceptionTestingInModelNode() throws IOExcepti
             listener.onResponse(Optional.of(detector));
             modelNodeInProgress.countDown();
             return null;
-        }).when(modelNodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class));
+        }).when(modelNodeStateManager).getConfig(anyString(), any(ActionListener.class));
 
         return modelNodeStateManager;
     }
 
     public void testEndRunNowInModelNode() throws InterruptedException, IOException {
-        NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
+        ADNodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
 
         CountDownLatch inProgress = new CountDownLatch(1);
         doAnswer(invocation -> {
@@ -1262,7 +1245,7 @@ public void testEndRunNowInModelNode() throws InterruptedException, IOException
     }
 
     public void testEndRunNowFalseInModelNode() throws InterruptedException, IOException {
-        NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
+        ADNodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
 
         when(modelNodeStateManager.fetchExceptionAndClear(anyString()))
             .thenReturn(
@@ -1310,7 +1293,7 @@ public void testEndRunNowFalseInModelNode() throws InterruptedException, IOExcep
      * @throws InterruptedException when failing to wait for inProgress to finish
      */
     public void testTimeOutExceptionInModelNode() throws IOException, InterruptedException {
-        NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
+        ADNodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
 
         when(modelNodeStateManager.fetchExceptionAndClear(anyString())).thenReturn(Optional.of(new OpenSearchTimeoutException("blah")));
@@ -1348,7 +1331,7 @@ public void testTimeOutExceptionInModelNode() throws IOException, InterruptedExc
     public void testSelectHigherExceptionInModelNode() throws InterruptedException, IOException {
         when(entityCache.get(any(), any())).thenThrow(EndRunException.class);
 
-        NodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
+        ADNodeStateManager modelNodeStateManager = setUpTestExceptionTestingInModelNode();
 
         when(modelNodeStateManager.fetchExceptionAndClear(anyString())).thenReturn(Optional.of(new OpenSearchTimeoutException("blah")));
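Reviewer note: the hunks above apply one mechanical rename throughout — NodeStateManager becomes ADNodeStateManager and getAnomalyDetector becomes getConfig, with the listener still in the same argument position. A minimal Mockito sketch of the new stubbing pattern follows; it uses only identifiers visible in this diff, while the wrapper class and method name are illustrative, not part of the change set.

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.ArgumentMatchers.anyString;
    import static org.mockito.Mockito.doAnswer;
    import static org.mockito.Mockito.mock;

    import java.util.Optional;

    import org.opensearch.action.ActionListener;
    import org.opensearch.ad.ADNodeStateManager;
    import org.opensearch.ad.model.AnomalyDetector;

    // Hypothetical helper, not part of the change set.
    public class GetConfigStubSketch {
        @SuppressWarnings("unchecked")
        public static ADNodeStateManager stubReturning(AnomalyDetector detector) {
            ADNodeStateManager stateManager = mock(ADNodeStateManager.class);
            // Matchers must be non-null, otherwise the stub is never invoked —
            // the same caveat the test comments above call out.
            doAnswer(invocation -> {
                ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
                listener.onResponse(Optional.of(detector));
                return null;
            }).when(stateManager).getConfig(anyString(), any(ActionListener.class));
            return stateManager;
        }
    }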
diff --git a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
index a60a350e6..369e14dd2 100644
--- a/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/PreviewAnomalyDetectorTransportActionTests.java
@@ -24,7 +24,6 @@
 import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
@@ -47,11 +46,6 @@
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.WriteRequest;
 import org.opensearch.ad.AnomalyDetectorRunner;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.feature.Features;
-import org.opensearch.ad.indices.ADIndexManagement;
-import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
@@ -86,9 +80,9 @@ public class PreviewAnomalyDetectorTransportActionTests extends OpenSearchSingle
     private AnomalyDetectorRunner runner;
     private ClusterService clusterService;
     private FeatureManager featureManager;
-    private ModelManager modelManager;
+    private ADModelManager modelManager;
     private Task task;
-    private ADCircuitBreakerService circuitBreaker;
+    private CircuitBreakerService circuitBreaker;
 
     @Override
     @Before
@@ -104,8 +98,8 @@ public void setUp() throws Exception {
                 Arrays
                     .asList(
                         AnomalyDetectorSettings.MAX_ANOMALY_FEATURES,
-                        AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES,
-                        AnomalyDetectorSettings.PAGE_SIZE,
+                        AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES,
+                        AnomalyDetectorSettings.AD_PAGE_SIZE,
                         AnomalyDetectorSettings.MAX_CONCURRENT_PREVIEW
                     )
             )
@@ -122,15 +116,17 @@ public void setUp() throws Exception {
             .build();
         final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID");
         IndexMetadata indexMetaData = IndexMetadata.builder(CommonName.CONFIG_INDEX).settings(existingSettings).build();
-        final Map<String, IndexMetadata> indices = new HashMap<>();
-        indices.put(CommonName.CONFIG_INDEX, indexMetaData);
+        final ImmutableOpenMap<String, IndexMetadata> indices = ImmutableOpenMap
+            .<String, IndexMetadata>builder()
+            .fPut(CommonName.CONFIG_INDEX, indexMetaData)
+            .build();
         ClusterState clusterState = ClusterState.builder(clusterName).metadata(Metadata.builder().indices(indices).build()).build();
         when(clusterService.state()).thenReturn(clusterState);
 
         featureManager = mock(FeatureManager.class);
-        modelManager = mock(ModelManager.class);
+        modelManager = mock(ADModelManager.class);
         runner = new AnomalyDetectorRunner(modelManager, featureManager, AnomalyDetectorSettings.MAX_PREVIEW_RESULTS);
-        circuitBreaker = mock(ADCircuitBreakerService.class);
+        circuitBreaker = mock(CircuitBreakerService.class);
         when(circuitBreaker.isOpen()).thenReturn(false);
         action = new PreviewAnomalyDetectorTransportAction(
             Settings.EMPTY,
@@ -278,7 +274,7 @@ public void onFailure(Exception e) {
     @Test
     public void testPreviewTransportActionNoContext() throws IOException, InterruptedException {
         final CountDownLatch inProgressLatch = new CountDownLatch(1);
-        Settings settings = Settings.builder().put(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES.getKey(), true).build();
+        Settings settings = Settings.builder().put(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES.getKey(), true).build();
         Client client = mock(Client.class);
         ThreadContext threadContext = new ThreadContext(settings);
         threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "alice|odfe,aes|engineering,operations");
diff --git a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
index f522d89f1..9944476ea 100644
--- a/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ProfileTransportActionTests.java
@@ -29,19 +29,17 @@
 import org.junit.Test;
 import org.opensearch.action.FailedNodeException;
 import org.opensearch.action.support.ActionFilters;
-import org.opensearch.ad.caching.CacheProvider;
-import org.opensearch.ad.caching.EntityCache;
-import org.opensearch.ad.feature.FeatureManager;
-import org.opensearch.ad.ml.ModelManager;
+import org.opensearch.ad.ml.ADModelManager;
 import org.opensearch.ad.model.DetectorProfileName;
-import org.opensearch.ad.model.ModelProfile;
 import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.plugins.Plugin;
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.timeseries.TimeSeriesAnalyticsPlugin;
+import org.opensearch.timeseries.feature.FeatureManager;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.model.ModelProfile;
 import org.opensearch.transport.TransportService;
 
 public class ProfileTransportActionTests extends OpenSearchIntegTestCase {
@@ -53,11 +51,11 @@ public class ProfileTransportActionTests extends OpenSearchIntegTestCase {
     private int shingleSize = 6;
     private long modelSize = 4456448L;
     private String modelId = "Pl536HEBnXkDrah03glg_model_rcf_1";
-    private CacheProvider cacheProvider;
+    private HCCacheProvider cacheProvider;
     private int activeEntities = 10;
     private long totalUpdates = 127;
     private long multiEntityModelSize = 712480L;
-    private ModelManager modelManager;
+    private ADModelManager modelManager;
     private FeatureManager featureManager;
 
     @Override
@@ -65,13 +63,13 @@ public class ProfileTransportActionTests extends OpenSearchIntegTestCase {
     public void setUp() throws Exception {
         super.setUp();
 
-        modelManager = mock(ModelManager.class);
+        modelManager = mock(ADModelManager.class);
         featureManager = mock(FeatureManager.class);
 
         when(featureManager.getShingleSize(any(String.class))).thenReturn(shingleSize);
 
         EntityCache cache = mock(EntityCache.class);
-        cacheProvider = mock(CacheProvider.class);
+        cacheProvider = mock(HCCacheProvider.class);
         when(cacheProvider.get()).thenReturn(cache);
         when(cache.getActiveEntities(anyString())).thenReturn(activeEntities);
         when(cache.getTotalUpdates(anyString())).thenReturn(totalUpdates);
@@ -114,7 +112,7 @@ public void setUp() throws Exception {
     }
 
     private void setUpModelSize(int maxModel) {
-        Settings nodeSettings = Settings.builder().put(AnomalyDetectorSettings.MAX_MODEL_SIZE_PER_NODE.getKey(), maxModel).build();
+        Settings nodeSettings = Settings.builder().put(AnomalyDetectorSettings.AD_MAX_MODEL_SIZE_PER_NODE.getKey(), maxModel).build();
         internalCluster().startNode(nodeSettings);
     }
diff --git a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
index edb480dd1..80f0b5510 100644
--- a/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
+++ b/src/test/java/org/opensearch/ad/transport/RCFPollingTests.java
@@ -29,10 +29,8 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
-import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.SingleStreamModelIdMapper;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.cluster.service.ClusterService;
@@ -43,7 +41,9 @@
 import org.opensearch.tasks.Task;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.cluster.HashRing;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;
 import org.opensearch.transport.ConnectTransportException;
 import org.opensearch.transport.Transport;
 import org.opensearch.transport.TransportException;
@@ -69,7 +69,7 @@ public class RCFPollingTests extends AbstractTimeSeriesTest {
     private ClusterService clusterService;
     private HashRing hashRing;
     private TransportAddress transportAddress1;
-    private ModelManager manager;
+    private ADModelManager manager;
     private TransportService transportService;
     private PlainActionFuture<RCFPollingResponse> future;
     private RCFPollingTransportAction action;
@@ -104,7 +104,7 @@ public void setUp() throws Exception {
         clusterService = mock(ClusterService.class);
         hashRing = mock(HashRing.class);
         transportAddress1 = new TransportAddress(new InetSocketAddress(InetAddress.getByName("1.2.3.4"), 9300));
-        manager = mock(ModelManager.class);
+        manager = mock(ADModelManager.class);
         transportService = new TransportService(
             Settings.EMPTY,
             mock(Transport.class),
@@ -189,7 +189,7 @@ public void testDoubleNaN() {
     public void testNormal() {
         DiscoveryNode localNode = new DiscoveryNode(nodeId, transportAddress1, Version.CURRENT.minimumCompatibilityVersion());
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.of(localNode));
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.of(localNode));
 
         when(clusterService.localNode()).thenReturn(localNode);
 
@@ -208,7 +208,7 @@ public void testNormal() {
     }
 
     public void testNoNodeFoundForModel() {
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class))).thenReturn(Optional.empty());
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class))).thenReturn(Optional.empty());
         action = new RCFPollingTransportAction(
             mock(ActionFilters.class),
             transportService,
@@ -305,7 +305,7 @@ public void testGetRemoteNormalResponse() {
             clusterService
         );
 
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
             .thenReturn(Optional.of(testNodes[1].discoveryNode()));
         registerHandler(testNodes[1]);
 
@@ -333,7 +333,7 @@ public void testGetRemoteFailureResponse() {
             clusterService
        );
 
-        when(hashRing.getOwningNodeWithSameLocalAdVersionForRealtimeAD(any(String.class)))
+        when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
            .thenReturn(Optional.of(testNodes[1].discoveryNode()));
        registerHandler(testNodes[1]);
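Reviewer note: the hash-ring lookup renamed above (getOwningNodeWithSameLocalAdVersionForRealtimeAD to getOwningNodeWithSameLocalVersionForRealtime) recurs across many test files. A sketch of the updated stub, using only the renamed method and the relocated org.opensearch.timeseries.cluster.HashRing import shown in this diff; the helper class is illustrative only.

    import static org.mockito.ArgumentMatchers.any;
    import static org.mockito.Mockito.mock;
    import static org.mockito.Mockito.when;

    import java.util.Optional;

    import org.opensearch.cluster.node.DiscoveryNode;
    import org.opensearch.timeseries.cluster.HashRing;

    // Hypothetical helper, not part of the change set.
    public class HashRingStubSketch {
        // Routes every model id to the supplied node via the renamed API.
        public static HashRing ringOwnedBy(DiscoveryNode node) {
            HashRing hashRing = mock(HashRing.class);
            when(hashRing.getOwningNodeWithSameLocalVersionForRealtime(any(String.class)))
                .thenReturn(Optional.of(node));
            return hashRing;
        }
    }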
diff --git a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
index 8f26af293..ee00583d8 100644
--- a/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/RCFResultTests.java
@@ -38,16 +38,11 @@
 import org.opensearch.action.ActionRequestValidationException;
 import org.opensearch.action.support.ActionFilters;
 import org.opensearch.action.support.PlainActionFuture;
-import org.opensearch.ad.breaker.ADCircuitBreakerService;
-import org.opensearch.ad.cluster.HashRing;
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
-import org.opensearch.ad.stats.ADStat;
 import org.opensearch.ad.stats.ADStats;
-import org.opensearch.ad.stats.suppliers.CounterSupplier;
 import org.opensearch.cluster.node.DiscoveryNode;
 import org.opensearch.common.Strings;
 import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -59,6 +54,7 @@
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.timeseries.common.exception.LimitExceededException;
 import org.opensearch.timeseries.stats.StatNames;
+import org.opensearch.timeseries.stats.suppliers.CounterSupplier;
 import org.opensearch.transport.Transport;
 import org.opensearch.transport.TransportService;
 
@@ -88,10 +84,10 @@ public void setUp() throws Exception {
         hashRing = mock(HashRing.class);
         node = mock(DiscoveryNode.class);
         doReturn(Optional.of(node)).when(hashRing).getNodeByAddress(any());
-        Map<String, ADStat<?>> statsMap = new HashMap<String, ADStat<?>>() {
+        Map<String, TimeSeriesStat<?>> statsMap = new HashMap<String, TimeSeriesStat<?>>() {
             {
-                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
-                put(StatNames.MODEL_CORRUTPION_COUNT.getName(), new ADStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_HC_EXECUTE_FAIL_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
+                put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
             }
         };
 
@@ -110,8 +106,8 @@ public void testNormal() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
-        ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class);
+        ADModelManager manager = mock(ADModelManager.class);
+        CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class);
         RCFResultTransportAction action = new RCFResultTransportAction(
             mock(ActionFilters.class),
             transportService,
@@ -168,8 +164,8 @@ public void testExecutionException() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
-        ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class);
+        ADModelManager manager = mock(ADModelManager.class);
+        CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class);
         RCFResultTransportAction action = new RCFResultTransportAction(
             mock(ActionFilters.class),
             transportService,
@@ -284,8 +280,8 @@ public void testCircuitBreaker() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
-        ADCircuitBreakerService breakerService = mock(ADCircuitBreakerService.class);
+        ADModelManager manager = mock(ADModelManager.class);
+        CircuitBreakerService breakerService = mock(CircuitBreakerService.class);
         RCFResultTransportAction action = new RCFResultTransportAction(
             mock(ActionFilters.class),
             transportService,
@@ -335,8 +331,8 @@ public void testCorruptModel() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
-        ADCircuitBreakerService adCircuitBreakerService = mock(ADCircuitBreakerService.class);
+        ADModelManager manager = mock(ADModelManager.class);
+        CircuitBreakerService adCircuitBreakerService = mock(CircuitBreakerService.class);
         RCFResultTransportAction action = new RCFResultTransportAction(
             mock(ActionFilters.class),
             transportService,
@@ -359,7 +355,7 @@ public void testCorruptModel() {
         action.doExecute(mock(Task.class), request, future);
 
         expectThrows(IllegalArgumentException.class, () -> future.actionGet());
-        Object val = adStats.getStat(StatNames.MODEL_CORRUTPION_COUNT.getName()).getValue();
+        Object val = adStats.getStat(StatNames.AD_MODEL_CORRUTPION_COUNT.getName()).getValue();
         assertEquals(1L, ((Long) val).longValue());
         verify(manager, times(1)).clear(eq(detectorId), any());
     }
diff --git a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java
index bc87faf13..65b0ee95d 100644
--- a/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/SearchADTasksTransportActionTests.java
@@ -24,13 +24,13 @@
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.ad.HistoricalAnalysisIntegTestCase;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.model.ADTask;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.index.IndexNotFoundException;
 import org.opensearch.index.query.BoolQueryBuilder;
 import org.opensearch.index.query.TermQueryBuilder;
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.test.OpenSearchIntegTestCase;
+import org.opensearch.timeseries.model.TimeSeriesTask;
 
 @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2)
 public class SearchADTasksTransportActionTests extends HistoricalAnalysisIntegTestCase {
@@ -81,7 +81,7 @@ public void testSearchWithExistingTask() throws IOException {
     private SearchRequest searchRequest(boolean isLatest) {
         SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
         BoolQueryBuilder query = new BoolQueryBuilder();
-        query.filter(new TermQueryBuilder(ADTask.IS_LATEST_FIELD, isLatest));
+        query.filter(new TermQueryBuilder(TimeSeriesTask.IS_LATEST_FIELD, isLatest));
         sourceBuilder.query(query);
         SearchRequest request = new SearchRequest().source(sourceBuilder).indices(ADCommonName.DETECTION_STATE_INDEX);
         return request;
diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java
index b67ec6aec..07a823513 100644
--- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyDetectorInfoActionTests.java
@@ -92,7 +92,7 @@ public void onFailure(Exception e) {
         clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
     }
diff --git a/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java
index ac902a55e..cf96ce70e 100644
--- a/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/SearchAnomalyResultActionTests.java
@@ -87,7 +87,7 @@ public void setUp() throws Exception {
         clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES)))
+            Collections.unmodifiableSet(new HashSet<>(Arrays.asList(AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES)))
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
         clusterState = createClusterState();
diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java
index 796d492e1..4130ba7ef 100644
--- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorActionTests.java
@@ -21,7 +21,6 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.opensearch.action.FailedNodeException;
-import org.opensearch.ad.stats.ADStatsResponse;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.common.io.stream.BytesStreamOutput;
 import org.opensearch.common.xcontent.XContentFactory;
diff --git a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java
index 7c877c086..4fa5b11da 100644
--- a/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/StatsAnomalyDetectorTransportActionTests.java
@@ -65,14 +65,14 @@ public void testStatsAnomalyDetectorWithNodeLevelStats() {
     public void testStatsAnomalyDetectorWithClusterLevelStats() {
         ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode());
         adStatsRequest.addStat(StatNames.DETECTOR_COUNT.getName());
-        adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName());
+        adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName());
         StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000);
         assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size());
         Map<String, Object> statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap();
         Map<String, Object> clusterStats = response.getAdStatsResponse().getClusterStats();
         assertEquals(0, statsMap.size());
         assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName()));
-        assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()));
+        assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()));
     }
 
     public void testStatsAnomalyDetectorWithDetectorCount() {
@@ -84,18 +84,18 @@ public void testStatsAnomalyDetectorWithDetectorCount() {
         Map<String, Object> clusterStats = response.getAdStatsResponse().getClusterStats();
         assertEquals(0, statsMap.size());
         assertEquals(2L, clusterStats.get(StatNames.DETECTOR_COUNT.getName()));
-        assertFalse(clusterStats.containsKey(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()));
+        assertFalse(clusterStats.containsKey(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()));
     }
 
     public void testStatsAnomalyDetectorWithSingleEntityDetectorCount() {
         ADStatsRequest adStatsRequest = new ADStatsRequest(clusterService().localNode());
-        adStatsRequest.addStat(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName());
+        adStatsRequest.addStat(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName());
         StatsAnomalyDetectorResponse response = client().execute(StatsAnomalyDetectorAction.INSTANCE, adStatsRequest).actionGet(5_000);
         assertEquals(1, response.getAdStatsResponse().getADStatsNodesResponse().getNodes().size());
         Map<String, Object> statsMap = response.getAdStatsResponse().getADStatsNodesResponse().getNodes().get(0).getStatsMap();
         Map<String, Object> clusterStats = response.getAdStatsResponse().getClusterStats();
         assertEquals(0, statsMap.size());
-        assertEquals(1L, clusterStats.get(StatNames.SINGLE_ENTITY_DETECTOR_COUNT.getName()));
+        assertEquals(1L, clusterStats.get(StatNames.SINGLE_STREAM_DETECTOR_COUNT.getName()));
         assertFalse(clusterStats.containsKey(StatNames.DETECTOR_COUNT.getName()));
     }
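Reviewer note: the stat registries above now use the shared TimeSeriesStat with CounterSupplier instead of ADStat. The sketch below rebuilds such a registry; the TimeSeriesStat import path is an assumption (it is not shown in any hunk), while ADStats, StatNames, and CounterSupplier appear verbatim in this diff.

    import java.util.HashMap;
    import java.util.Map;

    import org.opensearch.ad.stats.ADStats;
    import org.opensearch.timeseries.stats.StatNames;
    import org.opensearch.timeseries.stats.TimeSeriesStat; // assumed import path
    import org.opensearch.timeseries.stats.suppliers.CounterSupplier;

    // Hypothetical helper, not part of the change set.
    public class StatsRegistrySketch {
        // Builds a node-level counter registry the way the updated tests do.
        public static ADStats counterStats() {
            Map<String, TimeSeriesStat<?>> statsMap = new HashMap<>();
            statsMap.put(StatNames.AD_EXECUTE_REQUEST_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
            statsMap.put(StatNames.AD_MODEL_CORRUTPION_COUNT.getName(), new TimeSeriesStat<>(false, new CounterSupplier()));
            return new ADStats(statsMap);
        }
    }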
diff --git a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java
index d6ed84d2d..e6edcc93e 100644
--- a/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java
+++ b/src/test/java/org/opensearch/ad/transport/StopDetectorActionTests.java
@@ -27,6 +27,8 @@
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.XContentBuilder;
 import org.opensearch.test.OpenSearchIntegTestCase;
+import org.opensearch.timeseries.transport.StopConfigRequest;
+import org.opensearch.timeseries.transport.StopConfigResponse;
 
 public class StopDetectorActionTests extends OpenSearchIntegTestCase {
 
@@ -44,7 +46,7 @@ public void testStopDetectorAction() {
     @Test
     public void fromActionRequest_Success() {
-        StopDetectorRequest stopDetectorRequest = new StopDetectorRequest("adID");
+        StopConfigRequest stopDetectorRequest = new StopConfigRequest("adID");
         ActionRequest actionRequest = new ActionRequest() {
             @Override
             public ActionRequestValidationException validate() {
@@ -56,41 +58,41 @@ public void writeTo(StreamOutput out) throws IOException {
                 stopDetectorRequest.writeTo(out);
             }
         };
-        StopDetectorRequest result = StopDetectorRequest.fromActionRequest(actionRequest);
+        StopConfigRequest result = StopConfigRequest.fromActionRequest(actionRequest);
         assertNotSame(result, stopDetectorRequest);
-        assertEquals(result.getAdID(), stopDetectorRequest.getAdID());
+        assertEquals(result.getConfigID(), stopDetectorRequest.getConfigID());
     }
 
     @Test
     public void writeTo_Success() throws IOException {
         BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
-        StopDetectorResponse response = new StopDetectorResponse(true);
+        StopConfigResponse response = new StopConfigResponse(true);
         response.writeTo(bytesStreamOutput);
-        StopDetectorResponse parsedResponse = new StopDetectorResponse(bytesStreamOutput.bytes().streamInput());
+        StopConfigResponse parsedResponse = new StopConfigResponse(bytesStreamOutput.bytes().streamInput());
         assertNotEquals(response, parsedResponse);
         assertEquals(response.success(), parsedResponse.success());
     }
 
     @Test
     public void fromActionResponse_Success() throws IOException {
-        StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true);
+        StopConfigResponse stopDetectorResponse = new StopConfigResponse(true);
         ActionResponse actionResponse = new ActionResponse() {
             @Override
             public void writeTo(StreamOutput streamOutput) throws IOException {
                 stopDetectorResponse.writeTo(streamOutput);
             }
         };
-        StopDetectorResponse result = stopDetectorResponse.fromActionResponse(actionResponse);
+        StopConfigResponse result = stopDetectorResponse.fromActionResponse(actionResponse);
         assertNotSame(result, stopDetectorResponse);
         assertEquals(result.success(), stopDetectorResponse.success());
 
-        StopDetectorResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse);
+        StopConfigResponse parsedStopDetectorResponse = stopDetectorResponse.fromActionResponse(stopDetectorResponse);
         assertEquals(parsedStopDetectorResponse, stopDetectorResponse);
     }
 
     @Test
     public void toXContentTest() throws IOException {
-        StopDetectorResponse stopDetectorResponse = new StopDetectorResponse(true);
+        StopConfigResponse stopDetectorResponse = new StopConfigResponse(true);
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         stopDetectorResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
         assertNotNull(builder);
diff --git a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
index 9f2869c8c..bcaf05a7d 100644
--- a/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
+++ b/src/test/java/org/opensearch/ad/transport/ThresholdResultTests.java
@@ -30,7 +30,6 @@
 import org.opensearch.ad.common.exception.JsonPathNotFoundException;
 import org.opensearch.ad.constant.ADCommonMessages;
 import org.opensearch.ad.constant.ADCommonName;
-import org.opensearch.ad.ml.ModelManager;
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.common.Strings;
 import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -59,7 +58,7 @@ public void testNormal() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
+        ADModelManager manager = mock(ADModelManager.class);
         ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager);
         doAnswer(invocation -> {
             ActionListener<ThresholdingResult> listener = invocation.getArgument(3);
@@ -88,7 +87,7 @@ public void testExecutionException() {
             Collections.emptySet()
         );
 
-        ModelManager manager = mock(ModelManager.class);
+        ADModelManager manager = mock(ADModelManager.class);
         ThresholdResultTransportAction action = new ThresholdResultTransportAction(mock(ActionFilters.class), transportService, manager);
         doThrow(NullPointerException.class)
             .when(manager)
diff --git a/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java
index 793bdc7b9..a39f91895 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/ADSearchHandlerTests.java
@@ -17,7 +17,7 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
-import static org.opensearch.ad.settings.AnomalyDetectorSettings.FILTER_BY_BACKEND_ROLES;
+import static org.opensearch.ad.settings.AnomalyDetectorSettings.AD_FILTER_BY_BACKEND_ROLES;
 import static org.opensearch.timeseries.TestHelpers.matchAllRequest;
 
 import org.junit.Before;
@@ -50,8 +50,8 @@ public class ADSearchHandlerTests extends ADUnitTestCase {
     @Override
     public void setUp() throws Exception {
         super.setUp();
-        settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), false).build();
-        clusterSettings = clusterSetting(settings, FILTER_BY_BACKEND_ROLES);
+        settings = Settings.builder().put(AD_FILTER_BY_BACKEND_ROLES.getKey(), false).build();
+        clusterSettings = clusterSetting(settings, AD_FILTER_BY_BACKEND_ROLES);
         clusterService = new ClusterService(settings, clusterSettings, null);
         client = mock(Client.class);
         searchHandler = new ADSearchHandler(settings, clusterService, client);
@@ -74,7 +74,7 @@ public void testSearchException() {
     }
 
     public void testFilterEnabledWithWrongSearch() {
-        settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build();
+        settings = Settings.builder().put(AD_FILTER_BY_BACKEND_ROLES.getKey(), true).build();
         clusterService = new ClusterService(settings, clusterSettings, null);
         searchHandler = new ADSearchHandler(settings, clusterService, client);
 
@@ -83,7 +83,7 @@ public void testFilterEnabledWithWrongSearch() {
     }
 
     public void testFilterEnabled() {
-        settings = Settings.builder().put(FILTER_BY_BACKEND_ROLES.getKey(), true).build();
+        settings = Settings.builder().put(AD_FILTER_BY_BACKEND_ROLES.getKey(), true).build();
         clusterService = new ClusterService(settings, clusterSettings, null);
         searchHandler = new ADSearchHandler(settings, clusterService, client);
diff --git a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java
index 4d1c1ed44..6eb2d4c70 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/AbstractIndexHandlerTest.java
@@ -30,9 +30,6 @@
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.indices.ADIndexManagement;
 import org.opensearch.ad.transport.AnomalyResultTests;
-import org.opensearch.ad.util.ClientUtil;
-import org.opensearch.ad.util.IndexUtils;
-import org.opensearch.ad.util.Throttler;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.ClusterState;
 import org.opensearch.cluster.metadata.IndexMetadata;
@@ -43,6 +40,7 @@
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.AbstractTimeSeriesTest;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.util.IndexUtils;
 
 public abstract class AbstractIndexHandlerTest extends AbstractTimeSeriesTest {
     enum IndexCreation {
diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java
index a2635ed8f..b26887145 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultBulkIndexHandlerTests.java
@@ -37,9 +37,6 @@
 import org.opensearch.ad.ADUnitTestCase;
 import org.opensearch.ad.indices.ADIndexManagement;
 import org.opensearch.ad.model.AnomalyResult;
-import org.opensearch.ad.util.ClientUtil;
-import org.opensearch.ad.util.IndexUtils;
-import org.opensearch.ad.util.Throttler;
 import org.opensearch.client.Client;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
@@ -48,12 +45,15 @@
 import org.opensearch.index.engine.VersionConflictEngineException;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler;
+import org.opensearch.timeseries.util.ClientUtil;
+import org.opensearch.timeseries.util.IndexUtils;
 
 import com.google.common.collect.ImmutableList;
 
 public class AnomalyResultBulkIndexHandlerTests extends ADUnitTestCase {
-    private AnomalyResultBulkIndexHandler bulkIndexHandler;
+    private ResultBulkIndexingHandler bulkIndexHandler;
     private Client client;
     private IndexUtils indexUtils;
     private ActionListener<BulkResponse> listener;
@@ -72,7 +72,7 @@ public void setUp() throws Exception {
         indexUtils = mock(IndexUtils.class);
         ClusterService clusterService = mock(ClusterService.class);
         ThreadPool threadPool = mock(ThreadPool.class);
-        bulkIndexHandler = new AnomalyResultBulkIndexHandler(
+        bulkIndexHandler = new ResultBulkIndexingHandler(
             client,
             settings,
             threadPool,
diff --git a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
index 89367a72b..9ae4ddcf7 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/AnomalyResultHandlerTests.java
@@ -34,7 +34,7 @@
 import org.opensearch.action.ActionListener;
 import org.opensearch.action.index.IndexRequest;
 import org.opensearch.action.index.IndexResponse;
-import org.opensearch.ad.NodeStateManager;
+import org.opensearch.ad.ADNodeStateManager;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.common.settings.Settings;
@@ -42,10 +42,11 @@
 import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.transport.handler.ResultIndexingHandler;
 
 public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest {
     @Mock
-    private NodeStateManager nodeStateManager;
+    private ADNodeStateManager nodeStateManager;
 
     @Mock
     private Clock clock;
@@ -54,7 +55,7 @@ public class AnomalyResultHandlerTests extends AbstractIndexHandlerTest {
     @Before
     public void setUp() throws Exception {
         super.setUp();
-        super.setUpLog4jForJUnit(AnomalyIndexHandler.class);
+        super.setUpLog4jForJUnit(ResultIndexingHandler.class);
     }
 
     @Override
@@ -81,7 +82,7 @@ public void testSavingAdResult() throws IOException {
             listener.onResponse(mock(IndexResponse.class));
             return null;
         }).when(client).index(any(IndexRequest.class), ArgumentMatchers.<ActionListener<IndexResponse>>any());
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<>(
+        ResultIndexingHandler<AnomalyResult> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -92,16 +93,16 @@ public void testSavingAdResult() throws IOException {
             clusterService
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
     }
 
     @Test
     public void testSavingFailureNotRetry() throws InterruptedException, IOException {
         savingFailureTemplate(false, 1, true);
 
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true));
     }
 
     @Test
@@ -109,15 +110,15 @@ public void testSavingFailureRetry() throws InterruptedException, IOException {
         setWriteBlockAdResultIndex(false);
         savingFailureTemplate(true, 3, true);
 
-        assertEquals(2, testAppender.countMessage(AnomalyIndexHandler.RETRY_SAVING_ERR_MSG, true));
-        assertEquals(1, testAppender.countMessage(AnomalyIndexHandler.FAIL_TO_SAVE_ERR_MSG, true));
-        assertTrue(!testAppender.containsMessage(AnomalyIndexHandler.SUCCESS_SAVING_MSG, true));
+        assertEquals(2, testAppender.countMessage(ResultIndexingHandler.RETRY_SAVING_ERR_MSG, true));
+        assertEquals(1, testAppender.countMessage(ResultIndexingHandler.FAIL_TO_SAVE_ERR_MSG, true));
+        assertTrue(!testAppender.containsMessage(ResultIndexingHandler.SUCCESS_SAVING_MSG, true));
     }
 
     @Test
     public void testIndexWriteBlock() {
         setWriteBlockAdResultIndex(true);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<>(
+        ResultIndexingHandler<AnomalyResult> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -129,13 +130,13 @@ public void testIndexWriteBlock() {
         );
         handler.index(TestHelpers.randomAnomalyDetectResult(), detectorId, null);
 
-        assertTrue(testAppender.containsMessage(AnomalyIndexHandler.CANNOT_SAVE_ERR_MSG, true));
+        assertTrue(testAppender.containsMessage(ResultIndexingHandler.CANNOT_SAVE_ERR_MSG, true));
     }
 
     @Test
     public void testAdResultIndexExist() throws IOException {
         setUpSavingAnomalyResultIndex(false, IndexCreation.RESOURCE_EXISTS_EXCEPTION);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<>(
+        ResultIndexingHandler<AnomalyResult> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -155,7 +156,7 @@ public void testAdResultIndexOtherException() throws IOException {
         expectedEx.expectMessage("Error in saving .opendistro-anomaly-results for detector " + detectorId);
         setUpSavingAnomalyResultIndex(false, IndexCreation.RUNTIME_EXCEPTION);
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<>(
+        ResultIndexingHandler<AnomalyResult> handler = new ResultIndexingHandler<>(
             client,
             settings,
             threadPool,
@@ -213,7 +214,7 @@ private void savingFailureTemplate(boolean throwOpenSearchRejectedExecutionExcep
             .put("plugins.anomaly_detection.backoff_initial_delay", TimeValue.timeValueMillis(1))
             .build();
 
-        AnomalyIndexHandler<AnomalyResult> handler = new AnomalyIndexHandler<>(
+        ResultIndexingHandler<AnomalyResult> handler = new ResultIndexingHandler<>(
             client,
             backoffSettings,
             threadPool,
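Reviewer note: the StopDetectorActionTests changes above swap the AD-specific request/response for the shared StopConfigRequest and StopConfigResponse. The round-trip serialization they exercise reduces to the sketch below, built solely from constructors and methods shown in those hunks; the wrapper class is illustrative.

    import java.io.IOException;

    import org.opensearch.common.io.stream.BytesStreamOutput;
    import org.opensearch.timeseries.transport.StopConfigRequest;
    import org.opensearch.timeseries.transport.StopConfigResponse;

    // Hypothetical helper, not part of the change set.
    public class StopConfigRoundTripSketch {
        public static void roundTrip() throws IOException {
            StopConfigRequest request = new StopConfigRequest("adID");
            assert "adID".equals(request.getConfigID()); // getAdID() is gone

            StopConfigResponse response = new StopConfigResponse(true);
            BytesStreamOutput out = new BytesStreamOutput();
            response.writeTo(out);
            // Deserialize from the same bytes and compare the payload.
            StopConfigResponse parsed = new StopConfigResponse(out.bytes().streamInput());
            assert parsed.success() == response.success();
        }
    }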
diff --git a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
index 4c8446577..061d29c84 100644
--- a/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
+++ b/src/test/java/org/opensearch/ad/transport/handler/MultiEntityResultHandlerTests.java
@@ -24,24 +24,22 @@
 import org.junit.Test;
 import org.mockito.ArgumentMatchers;
 import org.opensearch.action.ActionListener;
-import org.opensearch.ad.ratelimit.RequestPriority;
-import org.opensearch.ad.ratelimit.ResultWriteRequest;
-import org.opensearch.ad.transport.ADResultBulkAction;
-import org.opensearch.ad.transport.ADResultBulkRequest;
-import org.opensearch.ad.transport.ADResultBulkResponse;
-import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.constant.CommonMessages;
+import org.opensearch.timeseries.ratelimit.RequestPriority;
+import org.opensearch.timeseries.ratelimit.ResultWriteRequest;
+import org.opensearch.timeseries.transport.ResultBulkResponse;
 
 public class MultiEntityResultHandlerTests extends AbstractIndexHandlerTest {
-    private MultiEntityResultHandler handler;
+    private ADIndexMemoryPressureAwareResultHandler handler;
     private ADResultBulkRequest request;
-    private ADResultBulkResponse response;
+    private ResultBulkResponse response;
 
     @Override
     public void setUp() throws Exception {
         super.setUp();
-        handler = new MultiEntityResultHandler(
+        handler = new ADIndexMemoryPressureAwareResultHandler(
             client,
             settings,
             threadPool,
@@ -61,15 +59,15 @@ public void setUp() throws Exception {
         );
         request.add(resultWriteRequest);
 
-        response = new ADResultBulkResponse();
+        response = new ResultBulkResponse();
 
-        super.setUpLog4jForJUnit(MultiEntityResultHandler.class);
+        super.setUpLog4jForJUnit(ADIndexMemoryPressureAwareResultHandler.class);
 
         doAnswer(invocation -> {
-            ActionListener<ADResultBulkResponse> listener = invocation.getArgument(2);
+            ActionListener<ResultBulkResponse> listener = invocation.getArgument(2);
             listener.onResponse(response);
             return null;
-        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ADResultBulkResponse>>any());
+        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ResultBulkResponse>>any());
     }
 
     @Override
@@ -89,10 +87,7 @@ public void testIndexWriteBlock() throws InterruptedException {
             verified.countDown();
         }, exception -> {
             assertTrue(exception instanceof TimeSeriesException);
-            assertTrue(
-                "actual: " + exception.getMessage(),
-                exception.getMessage().contains(MultiEntityResultHandler.CANNOT_SAVE_RESULT_ERR_MSG)
-            );
+            assertTrue("actual: " + exception.getMessage(), exception.getMessage().contains(CommonMessages.CANNOT_SAVE_RESULT_ERR_MSG));
             verified.countDown();
         }));
 
@@ -109,17 +104,17 @@ public void testSavingAdResult() throws IOException, InterruptedException {
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 
     @Test
     public void testSavingFailure() throws IOException, InterruptedException {
         setUpSavingAnomalyResultIndex(false);
         doAnswer(invocation -> {
-            ActionListener<ADResultBulkResponse> listener = invocation.getArgument(2);
+            ActionListener<ResultBulkResponse> listener = invocation.getArgument(2);
             listener.onFailure(new RuntimeException());
             return null;
-        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ADResultBulkResponse>>any());
+        }).when(client).execute(eq(ADResultBulkAction.INSTANCE), any(), ArgumentMatchers.<ActionListener<ResultBulkResponse>>any());
 
         CountDownLatch verified = new CountDownLatch(1);
         handler.flush(request, ActionListener.wrap(response -> {
@@ -142,7 +137,7 @@ public void testAdResultIndexExists() throws IOException, InterruptedException {
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 
     @Test
@@ -200,6 +195,6 @@ public void testCreateResourcExistsException() throws IOException, InterruptedEx
             verified.countDown();
         }));
         assertTrue(verified.await(100, TimeUnit.SECONDS));
-        assertEquals(1, testAppender.countMessage(MultiEntityResultHandler.SUCCESS_SAVING_RESULT_MSG, false));
+        assertEquals(1, testAppender.countMessage(CommonMessages.SUCCESS_SAVING_RESULT_MSG, false));
     }
 }
diff --git a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
index aadc2d999..5a5e35e81 100644
--- a/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
+++ b/src/test/java/org/opensearch/ad/util/BulkUtilTests.java
@@ -25,6 +25,7 @@
 import org.opensearch.core.index.shard.ShardId;
 import org.opensearch.index.engine.VersionConflictEngineException;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.util.BulkUtil;
 
 public class BulkUtilTests extends OpenSearchTestCase {
     public void testGetFailedIndexRequest() {
diff --git a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
index 593445b01..0a5a1fb40 100644
--- a/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/DateUtilsTests.java
@@ -15,6 +15,7 @@
 import org.opensearch.common.unit.TimeValue;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.util.DateUtils;
 
 public class DateUtilsTests extends OpenSearchTestCase {
     public void testDuration() {
diff --git a/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java
index 8d64ba08e..3a9ff1047 100644
--- a/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/ExceptionUtilsTests.java
@@ -18,6 +18,7 @@
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.timeseries.common.exception.TimeSeriesException;
+import org.opensearch.timeseries.util.ExceptionUtil;
 
 public class ExceptionUtilsTests extends OpenSearchTestCase {
diff --git a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
index bea6abf95..c385c25d7 100644
--- a/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/IndexUtilsTests.java
@@ -24,6 +24,7 @@
 import org.opensearch.test.OpenSearchIntegTestCase;
 import org.opensearch.threadpool.ThreadPool;
 import org.opensearch.timeseries.TestHelpers;
+import org.opensearch.timeseries.util.IndexUtils;
 
 public class IndexUtilsTests extends OpenSearchIntegTestCase {
diff --git a/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
index b905ce623..95ba7a999 100644
--- a/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
+++ b/src/test/java/org/opensearch/ad/util/MultiResponsesDelegateActionListenerTests.java
@@ -24,6 +24,7 @@
 import org.opensearch.ad.model.DetectorProfile;
 import org.opensearch.ad.model.EntityAnomalyResult;
 import org.opensearch.test.OpenSearchTestCase;
+import org.opensearch.timeseries.util.MultiResponsesDelegateActionListener;
 
 public class MultiResponsesDelegateActionListenerTests extends OpenSearchTestCase {
diff --git a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
index c2dd673b4..af919c1cd 100644
--- a/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
+++ b/src/test/java/org/opensearch/ad/util/ParseUtilsTests.java
@@ -11,7 +11,6 @@
 
 package org.opensearch.ad.util;
 
-import static org.opensearch.timeseries.util.ParseUtils.addUserBackendRolesFilter;
 import static org.opensearch.timeseries.util.ParseUtils.isAdmin;
 
 import java.io.IOException;
@@ -127,16 +126,17 @@ public void testGenerateInternalFeatureQuery() throws IOException {
     public void testAddUserRoleFilterWithNullUser() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(null, searchSourceBuilder);
+        ParseUtils.addUserBackendRolesFilter(null, searchSourceBuilder);
         assertEquals("{}", searchSourceBuilder.toString());
     }
 
     public void testAddUserRoleFilterWithNullUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(
-            new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(randomAlphaOfLength(5), null, ImmutableList.of(randomAlphaOfLength(5)), ImmutableList.of(randomAlphaOfLength(5))),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[],"
                 + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}],"
@@ -147,15 +147,16 @@ public void testAddUserRoleFilterWithNullUserBackendRole() {
     public void testAddUserRoleFilterWithEmptyUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
-        addUserBackendRolesFilter(
-            new User(
-                randomAlphaOfLength(5),
-                ImmutableList.of(),
-                ImmutableList.of(randomAlphaOfLength(5)),
-                ImmutableList.of(randomAlphaOfLength(5))
-            ),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(
+                    randomAlphaOfLength(5),
+                    ImmutableList.of(),
+                    ImmutableList.of(randomAlphaOfLength(5)),
+                    ImmutableList.of(randomAlphaOfLength(5))
+                ),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":[],"
                 + "\"boost\":1.0}},\"path\":\"user\",\"ignore_unmapped\":false,\"score_mode\":\"none\",\"boost\":1.0}}],"
@@ -168,15 +169,16 @@ public void testAddUserRoleFilterWithNormalUserBackendRole() {
         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
         String backendRole1 = randomAlphaOfLength(5);
         String backendRole2 = randomAlphaOfLength(5);
-        addUserBackendRolesFilter(
-            new User(
-                randomAlphaOfLength(5),
-                ImmutableList.of(backendRole1, backendRole2),
-                ImmutableList.of(randomAlphaOfLength(5)),
-                ImmutableList.of(randomAlphaOfLength(5))
-            ),
-            searchSourceBuilder
-        );
+        ParseUtils
+            .addUserBackendRolesFilter(
+                new User(
+                    randomAlphaOfLength(5),
+                    ImmutableList.of(backendRole1, backendRole2),
+                    ImmutableList.of(randomAlphaOfLength(5)),
+                    ImmutableList.of(randomAlphaOfLength(5))
+                ),
+                searchSourceBuilder
+            );
         assertEquals(
             "{\"query\":{\"bool\":{\"must\":[{\"nested\":{\"query\":{\"terms\":{\"user.backend_roles.keyword\":"
                 + "[\""
diff --git a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
index dda3a8761..2e67a961e 100644
--- a/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
+++ b/src/test/java/org/opensearch/forecast/settings/ForecastEnabledSettingTests.java
@@ -20,11 +20,4 @@ public void testIsForecastBreakerEnabled() {
         ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_BREAKER_ENABLED, false);
         assertTrue(!ForecastEnabledSetting.isForecastBreakerEnabled());
     }
-
-    public void testIsDoorKeeperInCacheEnabled() {
-        assertTrue(!ForecastEnabledSetting.isDoorKeeperInCacheEnabled());
-        ForecastEnabledSetting.getInstance().setSettingValue(ForecastEnabledSetting.FORECAST_DOOR_KEEPER_IN_CACHE_ENABLED, true);
-        assertTrue(ForecastEnabledSetting.isDoorKeeperInCacheEnabled());
-    }
-
 }
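Reviewer note: the ParseUtilsTests hunks above only change the call style — the static import of addUserBackendRolesFilter is dropped in favor of qualifying through ParseUtils. A sketch of the qualified usage follows; ParseUtils and the four-argument User constructor match this diff, while the org.opensearch.commons.authuser.User import path is an assumption.

    import org.opensearch.commons.authuser.User; // assumed import path
    import org.opensearch.search.builder.SearchSourceBuilder;
    import org.opensearch.timeseries.util.ParseUtils;

    import com.google.common.collect.ImmutableList;

    // Hypothetical helper, not part of the change set.
    public class BackendRoleFilterSketch {
        // Narrows a search to the user's backend roles via the qualified call.
        public static SearchSourceBuilder filteredFor(String name, String backendRole) {
            SearchSourceBuilder source = new SearchSourceBuilder();
            ParseUtils
                .addUserBackendRolesFilter(
                    new User(name, ImmutableList.of(backendRole), ImmutableList.of(), ImmutableList.of()),
                    source
                );
            return source;
        }
    }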
diff --git a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
index 7d9f9b1b2..2a9423ddb 100644
--- a/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
+++ b/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityProfileTests.java
@@ -29,16 +29,14 @@
 import org.opensearch.action.get.GetResponse;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
+import org.opensearch.ad.ADNodeStateManager;
 import org.opensearch.ad.AbstractProfileRunnerTests;
 import org.opensearch.ad.AnomalyDetectorProfileRunner;
-import org.opensearch.ad.NodeStateManager;
 import org.opensearch.ad.constant.ADCommonName;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.transport.ProfileAction;
 import org.opensearch.ad.transport.ProfileNodeResponse;
 import org.opensearch.ad.transport.ProfileResponse;
-import org.opensearch.ad.util.SecurityClientUtil;
 import org.opensearch.cluster.ClusterName;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.BigArrays;
@@ -47,6 +45,7 @@
 import org.opensearch.timeseries.TestHelpers;
 import org.opensearch.timeseries.constant.CommonName;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
 
 import com.carrotsearch.hppc.BitMixer;
 
@@ -71,12 +70,12 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus
         throws IOException {
         detector = TestHelpers
             .randomAnomalyDetectorWithInterval(new IntervalTimeConfiguration(detectorIntervalMin, ChronoUnit.MINUTES), true);
-        NodeStateManager nodeStateManager = mock(NodeStateManager.class);
+        ADNodeStateManager nodeStateManager = mock(ADNodeStateManager.class);
         doAnswer(invocation -> {
             ActionListener<Optional<AnomalyDetector>> listener = invocation.getArgument(1);
             listener.onResponse(Optional.of(detector));
             return null;
-        }).when(nodeStateManager).getAnomalyDetector(anyString(), any(ActionListener.class));
+        }).when(nodeStateManager).getConfig(anyString(), any(ActionListener.class));
         clientUtil = new SecurityClientUtil(nodeStateManager, Settings.EMPTY);
         runner = new AnomalyDetectorProfileRunner(
             client,
@@ -103,7 +102,7 @@ private void setUpMultiEntityClientGet(DetectorStatus detectorStatus, JobStatus
                     break;
                 }
             } else if (request.index().equals(CommonName.JOB_INDEX)) {
-                AnomalyDetectorJob job = null;
+                Job job = null;
                 switch (jobStatus) {
                     case ENABLED:
                         job = TestHelpers.randomAnomalyDetectorJob(true);
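// Editor's sketch (not part of the patch): the hunk above renames the stubbed
// method from getAnomalyDetector to getConfig but keeps the same Mockito idiom
// for faking a callback-style API. The general shape, with a hypothetical
// AsyncService interface standing in for ADNodeStateManager (the ActionListener
// package reflects the 2.x line this diff targets):

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;

import java.util.Optional;

import org.opensearch.action.ActionListener;

interface AsyncService<T> {
    void getConfig(String id, ActionListener<Optional<T>> listener);
}

class ListenerStubSketch {
    @SuppressWarnings("unchecked")
    static <T> AsyncService<T> stubbedService(T value) {
        AsyncService<T> service = mock(AsyncService.class);
        doAnswer(invocation -> {
            // argument 1 is the listener; complete it synchronously instead
            // of exercising the real async code path
            ActionListener<Optional<T>> listener = invocation.getArgument(1);
            listener.onResponse(Optional.of(value));
            return null;
        }).when(service).getConfig(anyString(), any(ActionListener.class));
        return service;
    }
}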
diff --git a/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java
index 8799b9be6..e2e2e2a76 100644
--- a/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java
+++ b/src/test/java/org/opensearch/timeseries/AbstractTimeSeriesTest.java
@@ -44,7 +44,6 @@
 import org.opensearch.action.ActionResponse;
 import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.ad.model.AnomalyDetector;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.model.DetectorInternalState;
 import org.opensearch.cluster.metadata.AliasMetadata;
@@ -64,6 +63,7 @@
 import org.opensearch.threadpool.FixedExecutorBuilder;
 import org.opensearch.threadpool.TestThreadPool;
 import org.opensearch.threadpool.ThreadPool;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.transport.TransportInterceptor;
 import org.opensearch.transport.TransportService;
 
@@ -351,7 +351,7 @@ protected NamedXContentRegistry xContentRegistry() {
                 AnomalyDetector.XCONTENT_REGISTRY,
                 AnomalyResult.XCONTENT_REGISTRY,
                 DetectorInternalState.XCONTENT_REGISTRY,
-                AnomalyDetectorJob.XCONTENT_REGISTRY
+                Job.XCONTENT_REGISTRY
             )
         );
         return new NamedXContentRegistry(entries);
diff --git a/src/test/java/org/opensearch/timeseries/TestHelpers.java b/src/test/java/org/opensearch/timeseries/TestHelpers.java
index 33ceb54fa..f33a3d469 100644
--- a/src/test/java/org/opensearch/timeseries/TestHelpers.java
+++ b/src/test/java/org/opensearch/timeseries/TestHelpers.java
@@ -67,18 +67,14 @@
 import org.opensearch.ad.ml.ThresholdingResult;
 import org.opensearch.ad.mock.model.MockSimpleLog;
 import org.opensearch.ad.model.ADTask;
-import org.opensearch.ad.model.ADTaskState;
 import org.opensearch.ad.model.ADTaskType;
 import org.opensearch.ad.model.AnomalyDetector;
 import org.opensearch.ad.model.AnomalyDetectorExecutionInput;
-import org.opensearch.ad.model.AnomalyDetectorJob;
 import org.opensearch.ad.model.AnomalyResult;
 import org.opensearch.ad.model.AnomalyResultBucket;
 import org.opensearch.ad.model.DetectorInternalState;
 import org.opensearch.ad.model.DetectorValidationIssue;
 import org.opensearch.ad.model.ExpectedValueList;
-import org.opensearch.ad.ratelimit.RequestPriority;
-import org.opensearch.ad.ratelimit.ResultWriteRequest;
 import org.opensearch.client.AdminClient;
 import org.opensearch.client.Client;
 import org.opensearch.client.Request;
@@ -143,6 +139,7 @@
 import org.opensearch.timeseries.model.Feature;
 import org.opensearch.timeseries.model.FeatureData;
 import org.opensearch.timeseries.model.IntervalTimeConfiguration;
+import org.opensearch.timeseries.model.Job;
 import org.opensearch.timeseries.model.TimeConfiguration;
 import org.opensearch.timeseries.model.ValidationAspect;
 import org.opensearch.timeseries.model.ValidationIssueType;
@@ -160,12 +157,12 @@ public class TestHelpers {
     public static final String AD_BASE_PREVIEW_URI = AD_BASE_DETECTORS_URI + "/%s/_preview";
     public static final String AD_BASE_STATS_URI = "/_plugins/_anomaly_detection/stats";
     public static ImmutableSet<String> HISTORICAL_ANALYSIS_RUNNING_STATS = ImmutableSet
-        .of(ADTaskState.CREATED.name(), ADTaskState.INIT.name(), ADTaskState.RUNNING.name());
+        .of(TaskState.CREATED.name(), TaskState.INIT.name(), TaskState.RUNNING.name());
     // Task may fail if memory circuit breaker triggered.
     public static final Set<String> HISTORICAL_ANALYSIS_FINISHED_FAILED_STATS = ImmutableSet
-        .of(ADTaskState.FINISHED.name(), ADTaskState.FAILED.name());
+        .of(TaskState.FINISHED.name(), TaskState.FAILED.name());
     public static ImmutableSet<String> HISTORICAL_ANALYSIS_DONE_STATS = ImmutableSet
-        .of(ADTaskState.FAILED.name(), ADTaskState.FINISHED.name(), ADTaskState.STOPPED.name());
+        .of(TaskState.FAILED.name(), TaskState.FINISHED.name(), TaskState.STOPPED.name());
 
     private static final Logger logger = LogManager.getLogger(TestHelpers.class);
     public static final Random random = new Random(42);
@@ -963,12 +960,12 @@ public static AnomalyResult randomHCADAnomalyDetectResult(
         );
     }
 
-    public static AnomalyDetectorJob randomAnomalyDetectorJob() {
+    public static Job randomAnomalyDetectorJob() {
         return randomAnomalyDetectorJob(true);
     }
 
-    public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled, Instant enabledTime, Instant disabledTime) {
-        return new AnomalyDetectorJob(
+    public static Job randomAnomalyDetectorJob(boolean enabled, Instant enabledTime, Instant disabledTime) {
+        return new Job(
             randomAlphaOfLength(10),
             randomIntervalSchedule(),
             randomIntervalTimeConfiguration(),
@@ -982,7 +979,7 @@ public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled, Insta
         );
     }
 
-    public static AnomalyDetectorJob randomAnomalyDetectorJob(boolean enabled) {
+    public static Job randomAnomalyDetectorJob(boolean enabled) {
         return randomAnomalyDetectorJob(
             enabled,
             Instant.now().truncatedTo(ChronoUnit.SECONDS),
@@ -1261,7 +1258,7 @@ public static ClusterState createClusterState() {
-        Map<String, IndexMetadata> mappings = new HashMap<>();
-
-        mappings
-            .put(
+        ImmutableOpenMap<String, IndexMetadata> immutableOpenMap = ImmutableOpenMap
+            .builder()
+            .fPut(
                 CommonName.JOB_INDEX,
                 IndexMetadata
                     .builder("test")
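// Editor's sketch (not part of the patch): registering Job.XCONTENT_REGISTRY
// in AbstractTimeSeriesTest is what lets stored job documents round-trip
// through XContent in tests. The consuming side looks roughly like the
// following; Job.parse is assumed to carry over from AnomalyDetectorJob.parse,
// and xcontent package locations vary across 2.x releases.

import java.io.IOException;

import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.timeseries.model.Job;

class JobParseSketch {
    static Job parseJob(NamedXContentRegistry registry, String jobJson) throws IOException {
        XContentParser parser = XContentType.JSON.xContent().createParser(registry, LoggingDeprecationHandler.INSTANCE, jobJson);
        parser.nextToken(); // position on START_OBJECT before delegating
        return Job.parse(parser);
    }
}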
diff --git a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java
index 3eb4fa80a..f265b3ff0 100644
--- a/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java
+++ b/src/test/java/test/org/opensearch/ad/util/ClusterCreation.java
@@ -106,4 +106,39 @@ public static ClusterState state(int numDataNodes) {
         }
         return state(new ClusterName("test"), clusterManagerNode, clusterManagerNode, allNodes);
     }
+
+    // Micro-benchmark: measures how long it takes to build and throw an
+    // exception whose message concatenates several large strings.
+    public static void main(String[] args) {
+        long start = System.currentTimeMillis();
+        boolean condition = true;
+        int index = 1;
+        int getCutValue = 2;
+        int getCutDimension = 3;
+        String leftBox = "abcccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc";
+        String rightBox = "deffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
+        try {
+            for (int i = 0; i < 10000; i++) {
+                if (condition) {
+                    // condition is always true, so the first iteration throws
+                    throw new IllegalStateException(
+                        " incorrect bounding state at index "
+                            + index
+                            + " cut value "
+                            + getCutValue
+                            + " cut dimension "
+                            + getCutDimension
+                            + " left Box "
+                            + leftBox
+                            + " right box "
+                            + rightBox
+                    );
+                }
+            }
+        } finally {
+            long finish = System.currentTimeMillis();
+            System.out.println(finish - start);
+        }
+    }
 }
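// Editor's sketch (not part of the patch): the main method above times a
// single throw with System.currentTimeMillis(), and its loop exits on the
// first iteration because condition is always true. If the goal is to gauge
// exception-construction cost, a nanoTime variant that completes the loop
// gives a steadier number:

class ExceptionCostSketch {
    public static void main(String[] args) {
        long start = System.nanoTime();
        long sink = 0;
        for (int i = 0; i < 10_000; i++) {
            try {
                throw new IllegalStateException("incorrect bounding state at index " + i);
            } catch (IllegalStateException e) {
                sink += e.getMessage().length(); // keep the JIT from discarding the work
            }
        }
        System.out.println((System.nanoTime() - start) / 1_000_000.0 + " ms, sink=" + sink);
    }
}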
diff --git a/src/test/java/test/org/opensearch/ad/util/MLUtil.java b/src/test/java/test/org/opensearch/ad/util/MLUtil.java
index babae59ef..c5eff9f0c 100644
--- a/src/test/java/test/org/opensearch/ad/util/MLUtil.java
+++ b/src/test/java/test/org/opensearch/ad/util/MLUtil.java
@@ -21,12 +21,10 @@
 import java.util.Random;
 import java.util.stream.IntStream;
 
-import org.opensearch.ad.ml.EntityModel;
-import org.opensearch.ad.ml.ModelManager.ModelType;
-import org.opensearch.ad.ml.ModelState;
-import org.opensearch.ad.settings.AnomalyDetectorSettings;
 import org.opensearch.common.collect.Tuple;
+import org.opensearch.timeseries.ml.ModelManager;
 import org.opensearch.timeseries.model.Entity;
+import org.opensearch.timeseries.settings.TimeSeriesSettings;
 
 import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
 
@@ -38,7 +36,7 @@
  */
 public class MLUtil {
     private static Random random = new Random(42);
-    private static int minSampleSize = AnomalyDetectorSettings.NUM_MIN_SAMPLES;
+    private static int minSampleSize = TimeSeriesSettings.NUM_MIN_SAMPLES;
 
     private static String randomString(int targetStringLength) {
         int leftLimit = 97; // letter 'a'
@@ -58,7 +56,7 @@ public static Queue<double[]> createQueueSamples(int size) {
         return res;
     }
 
-    public static ModelState<EntityModel> randomModelState(RandomModelStateConfig config) {
+    public static ADModelState<createFromValueOnlySamples<ThresholdedRandomCutForest>> randomModelState(RandomModelStateConfig config) {
         boolean fullModel = config.getFullModel() != null && config.getFullModel().booleanValue() ? true : false;
         float priority = config.getPriority() != null ? config.getPriority() : random.nextFloat();
         String detectorId = config.getId() != null ? config.getId() : randomString(15);
@@ -74,49 +72,53 @@ public static ModelState<EntityModel> randomModelState(RandomModelStateConfig co
         } else {
             entity = Entity.createSingleAttributeEntity("", "");
         }
-        EntityModel model = null;
+        createFromValueOnlySamples<ThresholdedRandomCutForest> model = null;
         if (fullModel) {
             model = createNonEmptyModel(detectorId, sampleSize, entity);
         } else {
             model = createEmptyModel(entity, sampleSize);
         }
-        return new ModelState<>(model, detectorId, detectorId, ModelType.ENTITY.getName(), clock, priority);
+        return new ADModelState<>(model, detectorId, detectorId, ModelManager.ModelType.ENTITY.getName(), clock, priority);
     }
 
-    public static EntityModel createEmptyModel(Entity entity, int sampleSize) {
+    public static createFromValueOnlySamples<ThresholdedRandomCutForest> createEmptyModel(Entity entity, int sampleSize) {
         Queue<double[]> samples = createQueueSamples(sampleSize);
-        return new EntityModel(entity, samples, null);
+        return new createFromValueOnlySamples(entity, samples, null);
     }
 
-    public static EntityModel createEmptyModel(Entity entity) {
+    public static createFromValueOnlySamples<ThresholdedRandomCutForest> createEmptyModel(Entity entity) {
         return createEmptyModel(entity, random.nextInt(minSampleSize));
     }
 
-    public static EntityModel createNonEmptyModel(String detectorId, int sampleSize, Entity entity) {
+    public static createFromValueOnlySamples<ThresholdedRandomCutForest> createNonEmptyModel(
+        String detectorId,
+        int sampleSize,
+        Entity entity
+    ) {
         Queue<double[]> samples = createQueueSamples(sampleSize);
-        int numDataPoints = random.nextInt(1000) + AnomalyDetectorSettings.NUM_MIN_SAMPLES;
+        int numDataPoints = random.nextInt(1000) + TimeSeriesSettings.NUM_MIN_SAMPLES;
         ThresholdedRandomCutForest trcf = new ThresholdedRandomCutForest(
             ThresholdedRandomCutForest
                 .builder()
                 .dimensions(1)
-                .sampleSize(AnomalyDetectorSettings.NUM_SAMPLES_PER_TREE)
-                .numberOfTrees(AnomalyDetectorSettings.NUM_TREES)
-                .timeDecay(AnomalyDetectorSettings.TIME_DECAY)
-                .outputAfter(AnomalyDetectorSettings.NUM_MIN_SAMPLES)
+                .sampleSize(TimeSeriesSettings.NUM_SAMPLES_PER_TREE)
+                .numberOfTrees(TimeSeriesSettings.NUM_TREES)
+                .timeDecay(TimeSeriesSettings.TIME_DECAY)
+                .outputAfter(TimeSeriesSettings.NUM_MIN_SAMPLES)
                 .initialAcceptFraction(0.125d)
                 .parallelExecutionEnabled(false)
                 .internalShinglingEnabled(true)
-                .anomalyRate(1 - AnomalyDetectorSettings.THRESHOLD_MIN_PVALUE)
+                .anomalyRate(1 - TimeSeriesSettings.THRESHOLD_MIN_PVALUE)
         );
         for (int i = 0; i < numDataPoints; i++) {
             trcf.process(new double[] { random.nextDouble() }, i);
         }
-        EntityModel entityModel = new EntityModel(entity, samples, trcf);
+        createFromValueOnlySamples<ThresholdedRandomCutForest> entityModel = new createFromValueOnlySamples<>(entity, samples, trcf);
         return entityModel;
     }
 
-    public static EntityModel createNonEmptyModel(String detectorId) {
+    public static createFromValueOnlySamples<ThresholdedRandomCutForest> createNonEmptyModel(String detectorId) {
         return createNonEmptyModel(detectorId, random.nextInt(minSampleSize), Entity.createSingleAttributeEntity("", ""));
     }
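// Editor's sketch (not part of the patch): a self-contained example of driving
// the ThresholdedRandomCutForest that createNonEmptyModel builds above. The
// builder parameters echo the TimeSeriesSettings values used in the diff;
// AnomalyDescriptor and its getAnomalyGrade accessor are the parkservices 3.x
// API used throughout the plugin.

import com.amazon.randomcutforest.parkservices.AnomalyDescriptor;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;

class TrcfSketch {
    public static void main(String[] args) {
        ThresholdedRandomCutForest forest = ThresholdedRandomCutForest
            .builder()
            .dimensions(1)
            .sampleSize(256)
            .numberOfTrees(30)
            .internalShinglingEnabled(true)
            .parallelExecutionEnabled(false)
            .build();
        for (int i = 0; i < 2000; i++) {
            // sine wave with a spike injected at i == 1500
            double value = Math.sin(i / 20.0) + (i == 1500 ? 10 : 0);
            AnomalyDescriptor result = forest.process(new double[] { value }, i);
            if (result.getAnomalyGrade() > 0) {
                System.out.println("anomaly at " + i + ", grade " + result.getAnomalyGrade());
            }
        }
    }
}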