Skip to content

Commit

Permalink
refactoring and addressing comments
Browse files Browse the repository at this point in the history
Signed-off-by: Bharathwaj G <[email protected]>
  • Loading branch information
bharath-techie committed Aug 18, 2024
1 parent c647863 commit b139000
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.index.compositeindex.datacube.startree.StarTreeField;
import org.opensearch.index.compositeindex.datacube.startree.builder.StarTreesBuilder;
import org.opensearch.index.mapper.CompositeMappedFieldType;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.index.mapper.MapperService;

import java.io.IOException;
Expand Down Expand Up @@ -61,15 +62,19 @@ public Composite99DocValuesWriter(DocValuesConsumer delegate, SegmentWriteState
this.compositeMappedFieldTypes = mapperService.getCompositeFieldTypes();
compositeFieldSet = new HashSet<>();
segmentFieldSet = new HashSet<>();
// TODO : add integ test for this
for (FieldInfo fi : segmentWriteState.fieldInfos) {
if (DocValuesType.SORTED_NUMERIC.equals(fi.getDocValuesType())) {
segmentFieldSet.add(fi.name);
} else if (fi.name.equals(DocCountFieldMapper.NAME)) {
segmentFieldSet.add(fi.name);
}
}
for (CompositeMappedFieldType type : compositeMappedFieldTypes) {
compositeFieldSet.addAll(type.fields());
}
// check if there are any composite fields which are part of the segment
// TODO : add integ test where there are no composite fields in a segment, test both flush and merge cases
segmentHasCompositeFields = Collections.disjoint(segmentFieldSet, compositeFieldSet) == false;
}

Expand Down Expand Up @@ -121,22 +126,7 @@ private void createCompositeIndicesIfPossible(DocValuesProducer valuesProducer,
if (segmentFieldSet.isEmpty()) {
Set<String> compositeFieldSetCopy = new HashSet<>(compositeFieldSet);
for (String compositeField : compositeFieldSetCopy) {
if (compositeField.equals("_doc_count")) {
fieldProducerMap.put(compositeField, new EmptyDocValuesProducer() {
@Override
public NumericDocValues getNumeric(FieldInfo field) {
return DocValues.emptyNumeric();
}
});
} else {
fieldProducerMap.put(compositeField, new EmptyDocValuesProducer() {
@Override
public SortedNumericDocValues getSortedNumeric(FieldInfo field) {
return DocValues.emptySortedNumeric();
}
});
}
compositeFieldSet.remove(compositeField);
addDocValuesForEmptyField(compositeField);
}
}
// we have all the required fields to build composite fields
Expand All @@ -149,7 +139,28 @@ public SortedNumericDocValues getSortedNumeric(FieldInfo field) {
}
}
}
}

/**
* Add empty doc values for fields not present in segment
*/
private void addDocValuesForEmptyField(String compositeField) {
if (compositeField.equals(DocCountFieldMapper.NAME)) {
fieldProducerMap.put(compositeField, new EmptyDocValuesProducer() {
@Override
public NumericDocValues getNumeric(FieldInfo field) {
return DocValues.emptyNumeric();
}
});
} else {
fieldProducerMap.put(compositeField, new EmptyDocValuesProducer() {
@Override
public SortedNumericDocValues getSortedNumeric(FieldInfo field) {
return DocValues.emptySortedNumeric();
}
});
}
compositeFieldSet.remove(compositeField);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class CountValueAggregator implements ValueAggregator<Long> {
public static final long DEFAULT_INITIAL_VALUE = 1L;
private static final StarTreeNumericType VALUE_AGGREGATOR_TYPE = StarTreeNumericType.LONG;

public CountValueAggregator(StarTreeNumericType starTreeNumericType) {}
public CountValueAggregator() {}

@Override
public StarTreeNumericType getAggregatedValueType() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ public class DocCountAggregator implements ValueAggregator<Long> {

private static final StarTreeNumericType VALUE_AGGREGATOR_TYPE = StarTreeNumericType.LONG;

public DocCountAggregator(StarTreeNumericType starTreeNumericType) {}
public DocCountAggregator() {}

@Override
public StarTreeNumericType getAggregatedValueType() {
return VALUE_AGGREGATOR_TYPE;
}

/**
* If _doc_count field for a doc is missing, we increment the _doc_count by '1' for the associated doc
* otherwise take the actual value present in the field
*/
@Override
public Long getInitialAggregatedValueForSegmentDocValue(Long segmentDocValue) {
if (segmentDocValue == null) {
Expand Down Expand Up @@ -56,6 +60,9 @@ public Long toStarTreeNumericTypeValue(Long value) {
return value;
}

/**
* If _doc_count field for a doc is missing, we increment the _doc_count by '1' for the associated doc
*/
@Override
public Long getIdentityMetricValue() {
return 1L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ public static ValueAggregator getValueAggregator(MetricStat aggregationType, Sta
case SUM:
return new SumValueAggregator(starTreeNumericType);
case COUNT:
return new CountValueAggregator(starTreeNumericType);
return new CountValueAggregator();
case MIN:
return new MinValueAggregator(starTreeNumericType);
case MAX:
return new MaxValueAggregator(starTreeNumericType);
case DOC_COUNT:
return new DocCountAggregator(starTreeNumericType);
return new DocCountAggregator();
default:
throw new IllegalStateException("Unsupported aggregation type: " + aggregationType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.index.compositeindex.datacube.startree.utils.SequentialDocValuesIterator;
import org.opensearch.index.compositeindex.datacube.startree.utils.TreeNode;
import org.opensearch.index.fielddata.IndexNumericFieldData;
import org.opensearch.index.mapper.DocCountFieldMapper;
import org.opensearch.index.mapper.Mapper;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.NumberFieldMapper;
Expand Down Expand Up @@ -118,7 +119,7 @@ protected BaseStarTreeBuilder(StarTreeField starTreeField, SegmentWriteState sta
public List<MetricAggregatorInfo> generateMetricAggregatorInfos(MapperService mapperService) {
List<MetricAggregatorInfo> metricAggregatorInfos = new ArrayList<>();
for (Metric metric : this.starTreeField.getMetrics()) {
if (metric.getField().equals("_doc_count")) {
if (metric.getField().equals(DocCountFieldMapper.NAME)) {
MetricAggregatorInfo metricAggregatorInfo = new MetricAggregatorInfo(
MetricStat.DOC_COUNT,
metric.getField(),
Expand Down Expand Up @@ -437,7 +438,7 @@ public void build(Map<String, DocValuesProducer> fieldProducerMap) throws IOExce
String dimension = dimensionsSplitOrder.get(i).getField();
FieldInfo dimensionFieldInfo = state.fieldInfos.fieldInfo(dimension);
if (dimensionFieldInfo == null) {
dimensionFieldInfo = getFieldInfo(dimension);
dimensionFieldInfo = getFieldInfo(dimension, DocValuesType.SORTED_NUMERIC);
}
dimensionReaders[i] = new SequentialDocValuesIterator(
fieldProducerMap.get(dimensionFieldInfo.name).getSortedNumeric(dimensionFieldInfo)
Expand All @@ -449,15 +450,15 @@ public void build(Map<String, DocValuesProducer> fieldProducerMap) throws IOExce
logger.debug("Finished Building star-tree in ms : {}", (System.currentTimeMillis() - startTime));
}

private static FieldInfo getFieldInfo(String field) {
private static FieldInfo getFieldInfo(String field, DocValuesType docValuesType) {
return new FieldInfo(
field,
1,
1, // This is filled as part of doc values creation and is not used otherwise
false,
false,
false,
IndexOptions.NONE,
DocValuesType.SORTED_NUMERIC,
docValuesType,
-1,
Collections.emptyMap(),
0,
Expand All @@ -483,12 +484,12 @@ public List<SequentialDocValuesIterator> getMetricReaders(SegmentWriteState stat
for (MetricStat metricStat : metric.getMetrics()) {
SequentialDocValuesIterator metricReader = null;
FieldInfo metricFieldInfo = state.fieldInfos.fieldInfo(metric.getField());
if (metricFieldInfo == null) {
metricFieldInfo = getFieldInfo(metric.getField());
}
if (metricStat.equals(MetricStat.DOC_COUNT)) {
metricReader = getDocCountMetricReader(fieldProducerMap, metricFieldInfo);
} else {
if (metricFieldInfo == null) {
metricFieldInfo = getFieldInfo(metric.getField(), DocValuesType.SORTED_NUMERIC);
}
metricReader = new SequentialDocValuesIterator(
fieldProducerMap.get(metricFieldInfo.name).getSortedNumeric(metricFieldInfo)
);
Expand All @@ -499,19 +500,17 @@ public List<SequentialDocValuesIterator> getMetricReaders(SegmentWriteState stat
return metricReaders;
}

private static SequentialDocValuesIterator getDocCountMetricReader(
Map<String, DocValuesProducer> fieldProducerMap,
FieldInfo metricFieldInfo
) throws IOException {
SequentialDocValuesIterator metricReader;
// _doc_count is numeric field , so we need to get sortedNumericDocValues
if (fieldProducerMap.containsKey(metricFieldInfo.name)) {
metricReader = new SequentialDocValuesIterator(
DocValues.singleton(fieldProducerMap.get(metricFieldInfo.name).getNumeric(metricFieldInfo))
);
} else {
metricReader = new SequentialDocValuesIterator(DocValues.emptySortedNumeric());
private SequentialDocValuesIterator getDocCountMetricReader(Map<String, DocValuesProducer> fieldProducerMap, FieldInfo metricFieldInfo)
throws IOException {
if (metricFieldInfo == null) {
metricFieldInfo = getFieldInfo(DocCountFieldMapper.NAME, DocValuesType.NUMERIC);
}
SequentialDocValuesIterator metricReader;
assert fieldProducerMap.containsKey(metricFieldInfo.name);
// _doc_count is numeric field , so we need to get convert to sortedNumericDocValues
metricReader = new SequentialDocValuesIterator(
DocValues.singleton(fieldProducerMap.get(metricFieldInfo.name).getNumeric(metricFieldInfo))
);
return metricReader;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.opensearch.common.annotation.ExperimentalApi;
Expand Down Expand Up @@ -139,12 +137,6 @@ Iterator<StarTreeDocument> mergeStarTrees(List<StarTreeValues> starTreeValuesSub
}
List<SequentialDocValuesIterator> metricReaders = new ArrayList<>();
for (Map.Entry<String, DocIdSetIterator> metricDocValuesEntry : starTreeValues.getMetricDocValuesIteratorMap().entrySet()) {
if (metricDocValuesEntry.getValue() instanceof NumericDocValues) {
metricReaders.add(
new SequentialDocValuesIterator(DocValues.singleton((NumericDocValues) metricDocValuesEntry.getValue()))
);
continue;
}
metricReaders.add(new SequentialDocValuesIterator(metricDocValuesEntry.getValue()));
}
int currentDocId = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ private List<Metric> buildMetrics(String fieldName, Map<String, Object> map, Map
for (Object metric : metricsList) {
Map<String, Object> metricMap = (Map<String, Object>) metric;
String name = (String) XContentMapValues.extractValue(CompositeDataCubeFieldType.NAME, metricMap);
if (name.equals("_doc_count")) {
// Handle _doc_count metric separately at the end
if (name.equals(DocCountFieldMapper.NAME)) {
continue;
}
metricMap.remove(CompositeDataCubeFieldType.NAME);
Expand All @@ -252,7 +253,7 @@ private List<Metric> buildMetrics(String fieldName, Map<String, Object> map, Map
} else {
throw new MapperParsingException(String.format(Locale.ROOT, "unable to parse metrics for star tree field [%s]", this.name));
}
Metric docCountMetric = new Metric("_doc_count", List.of(MetricStat.DOC_COUNT));
Metric docCountMetric = new Metric(DocCountFieldMapper.NAME, List.of(MetricStat.DOC_COUNT));
metrics.add(docCountMetric);
return metrics;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ public void testGetInitialAggregatedValueForSegmentDocNullValue() {
}

public void testMergeAggregatedNullValueAndSegmentNullValue() {
if (aggregator instanceof CountValueAggregator) {
assertThrows(AssertionError.class, () -> aggregator.mergeAggregatedValueAndSegmentValue(null, null));
} else {
assertEquals(aggregator.getIdentityMetricValue(), aggregator.mergeAggregatedValueAndSegmentValue(null, null));
}
assertEquals(aggregator.getIdentityMetricValue(), aggregator.mergeAggregatedValueAndSegmentValue(null, null));
}

public void testMergeAggregatedNullValues() {
Expand All @@ -65,13 +61,6 @@ public void testGetInitialAggregatedNullValue() {

public void testGetInitialAggregatedValueForSegmentDocValue() {
long randomLong = randomLong();
if (aggregator instanceof CountValueAggregator) {
assertEquals(CountValueAggregator.DEFAULT_INITIAL_VALUE, aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong()));
} else {
assertEquals(
aggregator.toStarTreeNumericTypeValue(randomLong),
aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong)
);
}
assertEquals(aggregator.toStarTreeNumericTypeValue(randomLong), aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ public void testMergeAggregatedValues() {
assertEquals(randomLong2, aggregator.mergeAggregatedValues(null, randomLong2), 0.0);
}

@Override
public void testMergeAggregatedNullValueAndSegmentNullValue() {
assertThrows(AssertionError.class, () -> aggregator.mergeAggregatedValueAndSegmentValue(null, null));
}

public void testGetInitialAggregatedValue() {
long randomLong = randomLong();
assertEquals(randomLong, aggregator.getInitialAggregatedValue(randomLong), 0.0);
Expand All @@ -48,8 +53,13 @@ public void testIdentityMetricValue() {

@Override
public ValueAggregator getValueAggregator(StarTreeNumericType starTreeNumericType) {
aggregator = new CountValueAggregator(starTreeNumericType);
aggregator = new CountValueAggregator();
return aggregator;
}

@Override
public void testGetInitialAggregatedValueForSegmentDocValue() {
long randomLong = randomLong();
assertEquals(CountValueAggregator.DEFAULT_INITIAL_VALUE, (long) aggregator.getInitialAggregatedValueForSegmentDocValue(randomLong));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.index.compositeindex.datacube.startree.aggregators;

import org.opensearch.index.compositeindex.datacube.startree.aggregators.numerictype.StarTreeNumericType;

/**
* Unit tests for {@link DocCountAggregator}.
*/
public class DocCountAggregatorTests extends AbstractValueAggregatorTests {

private DocCountAggregator aggregator;

public DocCountAggregatorTests(StarTreeNumericType starTreeNumericType) {
super(starTreeNumericType);
}

public void testMergeAggregatedValueAndSegmentValue() {
long randomLong = randomLong();
assertEquals(randomLong + 3L, (long) aggregator.mergeAggregatedValueAndSegmentValue(randomLong, 3L));
}

public void testMergeAggregatedValues() {
long randomLong1 = randomLong();
long randomLong2 = randomLong();
assertEquals(randomLong1 + randomLong2, (long) aggregator.mergeAggregatedValues(randomLong1, randomLong2));
assertEquals(randomLong1 + 1L, (long) aggregator.mergeAggregatedValues(randomLong1, null));
assertEquals(randomLong2 + 1L, (long) aggregator.mergeAggregatedValues(null, randomLong2));
}

@Override
public void testMergeAggregatedNullValueAndSegmentNullValue() {
assertThrows(AssertionError.class, () -> aggregator.mergeAggregatedValueAndSegmentValue(null, null));
}

@Override
public void testMergeAggregatedNullValues() {
assertEquals(
(aggregator.getIdentityMetricValue() + aggregator.getIdentityMetricValue()),
(long) aggregator.mergeAggregatedValues(null, null)
);
}

public void testGetInitialAggregatedValue() {
long randomLong = randomLong();
assertEquals(randomLong, (long) aggregator.getInitialAggregatedValue(randomLong));
}

public void testToStarTreeNumericTypeValue() {
long randomLong = randomLong();
assertEquals(randomLong, aggregator.toStarTreeNumericTypeValue(randomLong), 0.0);
assertNull(aggregator.toStarTreeNumericTypeValue(null));
}

public void testIdentityMetricValue() {
assertEquals(1L, (long) aggregator.getIdentityMetricValue());
}

@Override
public ValueAggregator getValueAggregator(StarTreeNumericType starTreeNumericType) {
aggregator = new DocCountAggregator();
return aggregator;
}
}
Loading

0 comments on commit b139000

Please sign in to comment.