Skip to content

Commit

Permalink
Upgrade rcf to 4.0
Browse files Browse the repository at this point in the history
This PR upgrades rcf to 4.0 as it has bug fixes and support for streaming imputation mode.

Testing done:
1. gradle build

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Mar 20, 2024
1 parent 1507dd4 commit 4445117
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 55 deletions.
9 changes: 6 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ dependencies {
implementation group: 'com.yahoo.datasketches', name: 'memory', version: '0.12.2'
implementation group: 'commons-lang', name: 'commons-lang', version: '2.6'
implementation group: 'org.apache.commons', name: 'commons-pool2', version: '2.11.1'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:3.8.0'
implementation 'software.amazon.randomcutforest:randomcutforest-serialization:4.0.0'
implementation 'software.amazon.randomcutforest:randomcutforest-parkservices:4.0.0'
implementation 'software.amazon.randomcutforest:randomcutforest-core:4.0.0'

// we inherit jackson-core from opensearch core
implementation "com.fasterxml.jackson.core:jackson-databind:2.16.1"
Expand All @@ -149,6 +149,9 @@ dependencies {
exclude group: 'org.ow2.asm', module: 'asm-tree'
}

// used for output encoding of config descriptions
implementation group: 'org.owasp.encoder' , name: 'encoder', version: '1.2.3'

testImplementation group: 'pl.pragmatists', name: 'JUnitParams', version: '1.1.1'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '5.9.0'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/ad/ml/CheckpointDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ private Optional<ThresholdedRandomCutForest> convertToTRCF(Optional<RandomCutFor
if (kllThreshold.isPresent()) {
scores = kllThreshold.get().extractScores();
}
return Optional.of(new ThresholdedRandomCutForest(rcf.get(), anomalyRate, scores));
// last parameter is lastShingledInput. Since we don't know it, use all 0 double array
return Optional.of(new ThresholdedRandomCutForest(rcf.get(), anomalyRate, scores, new double[rcf.get().getDimensions()]));
}

/**
Expand Down
26 changes: 13 additions & 13 deletions src/main/java/org/opensearch/ad/task/ADTaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,7 @@
import static org.opensearch.ad.constant.ADCommonMessages.NO_ELIGIBLE_NODE_TO_RUN_DETECTOR;
import static org.opensearch.ad.constant.ADCommonName.DETECTION_STATE_INDEX;
import static org.opensearch.ad.indices.ADIndexManagement.ALL_AD_RESULTS_INDEX_PATTERN;
import static org.opensearch.ad.model.ADTask.COORDINATING_NODE_FIELD;
import static org.opensearch.ad.model.ADTask.DETECTOR_ID_FIELD;
import static org.opensearch.ad.model.ADTask.ERROR_FIELD;
import static org.opensearch.ad.model.ADTask.ESTIMATED_MINUTES_LEFT_FIELD;
import static org.opensearch.ad.model.ADTask.EXECUTION_END_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.EXECUTION_START_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.INIT_PROGRESS_FIELD;
import static org.opensearch.ad.model.ADTask.IS_LATEST_FIELD;
import static org.opensearch.ad.model.ADTask.LAST_UPDATE_TIME_FIELD;
import static org.opensearch.ad.model.ADTask.PARENT_TASK_ID_FIELD;
import static org.opensearch.ad.model.ADTask.STATE_FIELD;
import static org.opensearch.ad.model.ADTask.STOPPED_BY_FIELD;
import static org.opensearch.ad.model.ADTask.TASK_PROGRESS_FIELD;
import static org.opensearch.ad.model.ADTask.TASK_TYPE_FIELD;
import static org.opensearch.ad.model.ADTaskType.ALL_HISTORICAL_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.HISTORICAL_DETECTOR_TASK_TYPES;
import static org.opensearch.ad.model.ADTaskType.REALTIME_TASK_TYPES;
Expand All @@ -52,6 +39,19 @@
import static org.opensearch.timeseries.constant.CommonName.TASK_ID_FIELD;
import static org.opensearch.timeseries.model.TaskState.NOT_ENDED_STATES;
import static org.opensearch.timeseries.model.TaskType.taskTypeToString;
import static org.opensearch.timeseries.model.TimeSeriesTask.COORDINATING_NODE_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.ERROR_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.ESTIMATED_MINUTES_LEFT_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.EXECUTION_END_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.EXECUTION_START_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.INIT_PROGRESS_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.IS_LATEST_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.LAST_UPDATE_TIME_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.PARENT_TASK_ID_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.STATE_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.STOPPED_BY_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_PROGRESS_FIELD;
import static org.opensearch.timeseries.model.TimeSeriesTask.TASK_TYPE_FIELD;
import static org.opensearch.timeseries.settings.TimeSeriesSettings.NUM_MIN_SAMPLES;
import static org.opensearch.timeseries.util.ExceptionUtil.getErrorMessage;
import static org.opensearch.timeseries.util.ExceptionUtil.getShardsFailure;
Expand Down
59 changes: 28 additions & 31 deletions src/test/java/org/opensearch/ad/ml/CheckpointDaoTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -1067,27 +1067,22 @@ public void testDeserializeTRCFModel() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(4.814651669367903);
scores.add(5.566968073093689);
scores.add(5.919907610660049);
scores.add(5.770278090352401);
scores.add(5.319779117320102);

List<Double> grade = new ArrayList<>();
grade.add(1.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
grade.add(0.0);
scores.add(5.052069275347555);
scores.add(6.117465704461799);
scores.add(6.6401649744661055);
scores.add(6.918514609476484);
scores.add(6.928318158276434);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down Expand Up @@ -1133,21 +1128,22 @@ public void testDeserialize_rcf3_rc3_single_stream_model() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data on RCF4.0. RCF4.0 changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(3.3830441158587066);
scores.add(2.825961659490065);
scores.add(2.4685871670647384);
scores.add(2.3123460886413647);
scores.add(2.1401987653477135);
scores.add(3.678754481587072);
scores.add(3.6809634269790252);
scores.add(3.683659822587799);
scores.add(3.6852688612219646);
scores.add(3.6859330728661064);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down Expand Up @@ -1190,21 +1186,22 @@ public void testDeserialize_rcf3_rc3_hc_model() throws Exception {
coldStartData.add(sample4);
coldStartData.add(sample5);

// This scores were generated with the sample data but on RCF3.0-rc1 and we are comparing them
// to the scores generated by the imported RCF3.0-rc2.1
// This scores were generated with the sample data but on RCF4.0 that changed implementation
// and we are seeing different rcf scores between 4.0 and 3.8. This is verified by switching
// rcf version between 3.8 and 4.0 while other code in AD unchanged. But we get different scores.
List<Double> scores = new ArrayList<>();
scores.add(1.86645896573027);
scores.add(1.8760247712797833);
scores.add(1.6809181763279901);
scores.add(1.7126716645678555);
scores.add(1.323776514074674);
scores.add(2.119532552959117);
scores.add(2.7347456872746325);
scores.add(3.066704948143919);
scores.add(3.2965580521876725);
scores.add(3.1888920146607047);

// rcf 3.8 has a number of improvements on thresholder and predictor corrector.
// We don't expect the results have the same anomaly grade.
for (int i = 0; i < coldStartData.size(); i++) {
forest.process(coldStartData.get(i), 0);
AnomalyDescriptor descriptor = forest.process(coldStartData.get(i), 0);
assertEquals(descriptor.getRCFScore(), scores.get(i), 1e-9);
assertEquals(scores.get(i), descriptor.getRCFScore(), 1e-9);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ public void testAccuracyOneMinuteIntervalNoInterpolation() throws Exception {
clusterService
);

accuracyTemplate(1, 0.6f, 0.6f);
accuracyTemplate(1, 0.5f, 0.5f);
}

private ModelState<EntityModel> createStateForCacheRelease() {
Expand Down
25 changes: 19 additions & 6 deletions src/test/java/org/opensearch/ad/task/ADTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.ad.ADUnitTestCase;
import org.opensearch.ad.cluster.HashRing;
import org.opensearch.ad.indices.ADIndexManagement;
import org.opensearch.ad.mock.model.MockSimpleLog;
Expand All @@ -89,6 +88,7 @@
import org.opensearch.ad.model.ADTaskType;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.rest.handler.IndexAnomalyDetectorJobActionHandler;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.stats.InternalStatNames;
import org.opensearch.ad.transport.ADStatsNodeResponse;
import org.opensearch.ad.transport.ADStatsNodesResponse;
Expand Down Expand Up @@ -120,6 +120,7 @@
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AbstractTimeSeriesTest;
import org.opensearch.timeseries.TestHelpers;
import org.opensearch.timeseries.common.exception.DuplicateTaskException;
import org.opensearch.timeseries.constant.CommonName;
Expand All @@ -139,7 +140,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public class ADTaskManagerTests extends ADUnitTestCase {
public class ADTaskManagerTests extends AbstractTimeSeriesTest {

private Settings settings;
private Client client;
Expand Down Expand Up @@ -1447,10 +1448,22 @@ public void testForwardRequestToLeadNodeWithNotExistingNode() throws IOException
@SuppressWarnings("unchecked")
public void testScaleTaskLaneOnCoordinatingNode() {
ADTask adTask = mock(ADTask.class);
when(adTask.getCoordinatingNode()).thenReturn(node1.getId());
when(nodeFilter.getEligibleDataNodes()).thenReturn(new DiscoveryNode[] { node1, node2 });
ActionListener<JobResponse> listener = mock(ActionListener.class);
adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, transportService, listener);
try {
// bring up real transport service as mockito cannot mock final method
// and transportService.sendRequest is called. A lot of null pointer
// exception will be thrown if we use mocked transport service.
setUpThreadPool(ADTaskManagerTests.class.getSimpleName());
setupTestNodes(AnomalyDetectorSettings.AD_MAX_ENTITIES_PER_QUERY, AnomalyDetectorSettings.AD_PAGE_SIZE);
when(adTask.getCoordinatingNode()).thenReturn(testNodes[1].getNodeId());
when(nodeFilter.getEligibleDataNodes())
.thenReturn(new DiscoveryNode[] { testNodes[0].discoveryNode(), testNodes[1].discoveryNode() });
ActionListener<JobResponse> listener = mock(ActionListener.class);

adTaskManager.scaleTaskLaneOnCoordinatingNode(adTask, 2, testNodes[1].transportService, listener);
} finally {
tearDownTestNodes();
tearDownThreadPool();
}
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand All @@ -30,6 +33,8 @@
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.logging.log4j.Level;
Expand All @@ -40,18 +45,24 @@
import org.apache.logging.log4j.core.config.Property;
import org.apache.logging.log4j.core.layout.PatternLayout;
import org.apache.logging.log4j.util.StackLocatorUtil;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.model.AnomalyResult;
import org.opensearch.ad.model.DetectorInternalState;
import org.opensearch.cluster.metadata.AliasMetadata;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.logging.Loggers;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.http.HttpRequest;
Expand All @@ -67,10 +78,22 @@
import org.opensearch.transport.TransportInterceptor;
import org.opensearch.transport.TransportService;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;

import test.org.opensearch.ad.util.FakeNode;

public class AbstractTimeSeriesTest extends OpenSearchTestCase {

@Captor
protected ArgumentCaptor<Exception> exceptionCaptor;

@Override
public void setUp() throws Exception {
super.setUp();
MockitoAnnotations.initMocks(this);
}

protected static final Logger LOG = (Logger) LogManager.getLogger(AbstractTimeSeriesTest.class);

// transport test node
Expand Down Expand Up @@ -452,4 +475,39 @@ protected void setUpADThreadPool(ThreadPool mockThreadPool) {
return null;
}).when(executorService).execute(any(Runnable.class));
}

/**
* Create cluster setting.
*
* @param settings cluster settings
* @param setting add setting if the code to be tested contains setting update consumer
* @return instance of ClusterSettings
*/
public ClusterSettings clusterSetting(Settings settings, Setting<?>... setting) {
final Set<Setting<?>> settingsSet = Stream
.concat(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS.stream(), Sets.newHashSet(setting).stream())
.collect(Collectors.toSet());
ClusterSettings clusterSettings = new ClusterSettings(settings, settingsSet);
return clusterSettings;
}

protected DiscoveryNode createNode(String nodeId) {
return new DiscoveryNode(
nodeId,
new TransportAddress(TransportAddress.META_ADDRESS, 9300),
ImmutableMap.of(),
BUILT_IN_ROLES,
Version.CURRENT
);
}

protected DiscoveryNode createNode(String nodeId, String ip, int port, Map<String, String> attributes) throws UnknownHostException {
return new DiscoveryNode(
nodeId,
new TransportAddress(InetAddress.getByName(ip), port),
attributes,
BUILT_IN_ROLES,
Version.CURRENT
);
}
}

0 comments on commit 4445117

Please sign in to comment.