Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce execution hint for Cardinality aggregation #15764

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation

private static final ParseField REHASH = new ParseField("rehash").withAllDeprecated("no replacement - values will always be rehashed");
public static final ParseField PRECISION_THRESHOLD_FIELD = new ParseField("precision_threshold");
public static final ParseField EXECUTION_HINT_FIELD = new ParseField(("execution_hint"));

public static final ObjectParser<CardinalityAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
NAME,
Expand All @@ -76,6 +77,7 @@ public final class CardinalityAggregationBuilder extends ValuesSourceAggregation
static {
ValuesSourceAggregationBuilder.declareFields(PARSER, true, false, false);
PARSER.declareLong(CardinalityAggregationBuilder::precisionThreshold, CardinalityAggregationBuilder.PRECISION_THRESHOLD_FIELD);
PARSER.declareString(CardinalityAggregationBuilder::executionHint, CardinalityAggregationBuilder.EXECUTION_HINT_FIELD);
PARSER.declareLong((b, v) -> {/*ignore*/}, REHASH);
}

Expand All @@ -85,6 +87,8 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) {

private Long precisionThreshold = null;

private String executionHint = null;

public CardinalityAggregationBuilder(String name) {
super(name);
}
Expand All @@ -96,6 +100,7 @@ public CardinalityAggregationBuilder(
) {
super(clone, factoriesBuilder, metadata);
this.precisionThreshold = clone.precisionThreshold;
this.executionHint = clone.executionHint;
}

@Override
Expand All @@ -111,6 +116,7 @@ public CardinalityAggregationBuilder(StreamInput in) throws IOException {
if (in.readBoolean()) {
precisionThreshold = in.readLong();
}
executionHint = in.readOptionalString();
}

@Override
Expand All @@ -125,6 +131,7 @@ protected void innerWriteTo(StreamOutput out) throws IOException {
if (hasPrecisionThreshold) {
out.writeLong(precisionThreshold);
}
out.writeOptionalString(executionHint);
}

@Override
Expand Down Expand Up @@ -155,27 +162,48 @@ public Long precisionThreshold() {
return precisionThreshold;
}

public CardinalityAggregationBuilder executionHint(String executionHint) {
this.executionHint = executionHint;
return this;
}

public String executionHint() {
return executionHint;
}

@Override
protected CardinalityAggregatorFactory innerBuild(
QueryShardContext queryShardContext,
ValuesSourceConfig config,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder
) throws IOException {
return new CardinalityAggregatorFactory(name, config, precisionThreshold, queryShardContext, parent, subFactoriesBuilder, metadata);
return new CardinalityAggregatorFactory(
name,
config,
precisionThreshold,
executionHint,
queryShardContext,
parent,
subFactoriesBuilder,
metadata
);
}

@Override
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
if (precisionThreshold != null) {
builder.field(PRECISION_THRESHOLD_FIELD.getPreferredName(), precisionThreshold);
}
if (executionHint != null) {
builder.field(EXECUTION_HINT_FIELD.getPreferredName(), executionHint);
}
return builder;
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), precisionThreshold);
return Objects.hash(super.hashCode(), precisionThreshold, executionHint);
}

@Override
Expand All @@ -184,7 +212,7 @@ public boolean equals(Object obj) {
if (obj == null || getClass() != obj.getClass()) return false;
if (super.equals(obj) == false) return false;
CardinalityAggregationBuilder other = (CardinalityAggregationBuilder) obj;
return Objects.equals(precisionThreshold, other.precisionThreshold);
return Objects.equals(precisionThreshold, other.precisionThreshold) && Objects.equals(executionHint, other.executionHint);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class CardinalityAggregator extends NumericMetricsAggregator.SingleValue

private static final Logger logger = LogManager.getLogger(CardinalityAggregator.class);

private final CardinalityAggregatorFactory.ExecutionMode executionMode;
private final int precision;
private final ValuesSource valuesSource;

Expand All @@ -111,6 +112,7 @@ public CardinalityAggregator(
String name,
ValuesSourceConfig valuesSourceConfig,
int precision,
CardinalityAggregatorFactory.ExecutionMode executionMode,
SearchContext context,
Aggregator parent,
Map<String, Object> metadata
Expand All @@ -121,6 +123,7 @@ public CardinalityAggregator(
this.precision = precision;
this.counts = valuesSource == null ? null : new HyperLogLogPlusPlus(precision, context.bigArrays(), 1);
this.valuesSourceConfig = valuesSourceConfig;
this.executionMode = executionMode;
}

@Override
Expand Down Expand Up @@ -151,6 +154,9 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
if (maxOrd == 0) {
emptyCollectorsUsed++;
return new EmptyCollector();
} else if (executionMode == CardinalityAggregatorFactory.ExecutionMode.ORDINALS) { // Force OrdinalsCollector
ordinalsCollectorsUsed++;
collector = new OrdinalsCollector(counts, ordinalValues, context.bigArrays());
} else {
final long ordinalsMemoryUsage = OrdinalsCollector.memoryOverhead(maxOrd);
final long countsMemoryUsage = HyperLogLogPlusPlus.memoryUsage(precision);
Expand Down Expand Up @@ -480,7 +486,7 @@ public void close() {
*
* @opensearch.internal
*/
private static class DirectCollector extends Collector {
public static class DirectCollector extends Collector {

private final MurmurHash3Values hashes;
private final HyperLogLogPlusPlus counts;
Expand Down Expand Up @@ -517,7 +523,7 @@ public void close() {
*
* @opensearch.internal
*/
private static class OrdinalsCollector extends Collector {
public static class OrdinalsCollector extends Collector {

private static final long SHALLOW_FIXEDBITSET_SIZE = RamUsageEstimator.shallowSizeOfInstance(FixedBitSet.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.Locale;
import java.util.Map;

/**
Expand All @@ -53,19 +54,53 @@
*/
class CardinalityAggregatorFactory extends ValuesSourceAggregatorFactory {

/**
* Execution mode for cardinality agg
*
* @opensearch.internal
*/
public enum ExecutionMode {

UNSET,
DIRECT,
ORDINALS;

ExecutionMode() {}

public static ExecutionMode fromString(String value) {
if (value == null) {
return UNSET;
}
try {
return ExecutionMode.valueOf(value.toUpperCase(Locale.ROOT));
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Unknown `execution_hint`: [" + value + "], expected any of [direct, ordinals]");
}
}

@Override
public String toString() {
return this.name().toLowerCase(Locale.ROOT);
}
}

private final ExecutionMode executionMode;

private final Long precisionThreshold;

CardinalityAggregatorFactory(
String name,
ValuesSourceConfig config,
Long precisionThreshold,
String executionHint,
QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metadata
) throws IOException {
super(name, config, queryShardContext, parent, subFactoriesBuilder, metadata);
this.precisionThreshold = precisionThreshold;
this.executionMode = ExecutionMode.fromString(executionHint);
}

public static void registerAggregators(ValuesSourceRegistry.Builder builder) {
Expand All @@ -74,7 +109,7 @@ public static void registerAggregators(ValuesSourceRegistry.Builder builder) {

@Override
protected Aggregator createUnmapped(SearchContext searchContext, Aggregator parent, Map<String, Object> metadata) throws IOException {
return new CardinalityAggregator(name, config, precision(), searchContext, parent, metadata);
return new CardinalityAggregator(name, config, precision(), executionMode, searchContext, parent, metadata);
}

@Override
Expand All @@ -86,7 +121,7 @@ protected Aggregator doCreateInternal(
) throws IOException {
return queryShardContext.getValuesSourceRegistry()
.getAggregator(CardinalityAggregationBuilder.REGISTRY_KEY, config)
.build(name, config, precision(), searchContext, parent, metadata);
.build(name, config, precision(), executionMode, searchContext, parent, metadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Aggregator build(
String name,
ValuesSourceConfig valuesSourceConfig,
int precision,
CardinalityAggregatorFactory.ExecutionMode executionMode,
SearchContext context,
Aggregator parent,
Map<String, Object> metadata
Expand Down
Loading
Loading