Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for $median aggregation expression #4515

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4472-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4472-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4472-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4472-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,16 @@ public Percentile percentile(Double... percentages) {
return percentile.percentages(percentages);
}

/**
* Creates new {@link AggregationExpression} that calculates the median of the associated numeric value expression.
*
* @return new instance of {@link Median}.
* @since 4.2
*/
public Median median() {
return usesFieldRef() ? Median.medianOf(fieldReference) : Median.medianOf(expression);
}

private boolean usesFieldRef() {
return fieldReference != null;
}
Expand Down Expand Up @@ -1082,4 +1092,78 @@ protected String getMongoMethod() {
return "$percentile";
}
}

/**
* {@link AggregationExpression} for {@code $median}.
*
* @author Julia Lee
* @since 4.2
*/
public static class Median extends AbstractAggregationExpression {

private Median(Object value) {
super(value);
}

/**
* Creates new {@link Median}.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public static Median medianOf(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null");
Map<String, Object> fields = new HashMap<>();
fields.put("input", Fields.field(fieldReference));
fields.put("method", "approximate");
return new Median(fields);
}

/**
* Creates new {@link Median}.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public static Median medianOf(AggregationExpression expression) {

Assert.notNull(expression, "Expression must not be null");
Map<String, Object> fields = new HashMap<>();
fields.put("input", expression);
fields.put("method", "approximate");
return new Median(fields);
}

/**
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
*
* @param fieldReference must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public Median and(String fieldReference) {

Assert.notNull(fieldReference, "FieldReference must not be null");
return new Median(appendTo("input", Fields.field(fieldReference)));
}

/**
* Creates new {@link Median} with all previously added inputs appending the given one. <br />
* <strong>NOTE:</strong> Only possible in {@code $project} stage.
*
* @param expression must not be {@literal null}.
* @return new instance of {@link Median}.
*/
public Median and(AggregationExpression expression) {

Assert.notNull(expression, "Expression must not be null");
return new Median(appendTo("input", expression));
}

@Override
protected String getMongoMethod() {
return "$median";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Median;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Percentile;
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
Expand Down Expand Up @@ -948,6 +949,18 @@ public Percentile percentile(Double... percentages) {
return percentile.percentages(percentages);
}

/**
* Creates new {@link AggregationExpression} that calculates the requested percentile(s) of the
* numeric value.
*
* @return new instance of {@link Median}.
* @since 4.2
*/
public Median median() {
return usesFieldRef() ? AccumulatorOperators.Median.medianOf(fieldReference)
: AccumulatorOperators.Median.medianOf(expression);
}

private boolean usesFieldRef() {
return fieldReference != null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,34 @@ void rendersPercentileWithExpression() {
.isEqualTo(Document.parse("{ $percentile: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\", p: [0.1, 0.2] } }"));
}

@Test // GH-4472
void rendersMedianWithFieldReference() {

assertThat(valueOf("score").median().toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: \"$score\", method: \"approximate\" } }"));

assertThat(valueOf("score").median().and("scoreTwo").toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }"));
}

@Test // GH-4472
void rendersMedianWithExpression() {

assertThat(valueOf(Sum.sumOf("score")).median().toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: {\"$sum\": \"$score\"}, method: \"approximate\" } }"));

assertThat(valueOf("scoreOne").median().and(Sum.sumOf("scoreTwo")).toDocument(Aggregation.DEFAULT_CONTEXT))
.isEqualTo(Document.parse("{ $median: { input: [\"$scoreOne\", {\"$sum\": \"$scoreTwo\"}], method: \"approximate\" } }"));
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add one more test using a TypeBasedAggregationOperationContext?
something like:

assertThat(valueOf("midichlorianCount").median().toDocument(contextFor(Jedi.class)))
    .isEqualTo("{ $median: { input: '$force', ...

@Test // GH-4472
void rendersMedianCorrectlyWithTypedAggregationContext() {

assertThat(valueOf("midichlorianCount").median()
.toDocument(TestAggregationContext.contextFor(Jedi.class)))
.isEqualTo(Document.parse("{ $median: { input: \"$force\", method: \"approximate\" } }"));
}

static class Jedi {

String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1897,19 +1897,44 @@ void facetShouldCreateFacets() {
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
void percentileShouldBeAppliedCorrectly() {

mongoTemplate.insert(new DATAMONGO788(15, 16));
mongoTemplate.insert(new DATAMONGO788(17, 18));
DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);

mongoTemplate.insert(objectToSave);
mongoTemplate.insert(objectToSave2);

Aggregation agg = Aggregation.newAggregation(
project().and(ArithmeticOperators.valueOf("x").percentile(0.9).and("y"))
.as("ninetiethPercentile"));
project().and(ArithmeticOperators.valueOf("x").percentile(0.9, 0.4).and("y").and("xField"))
.as("percentileValues"));

AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);

// MongoDB server returns $percentile as an array of doubles
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
assertThat((List<Object>) rawResults.get(0).get("ninetiethPercentile")).containsExactly(16.0);
assertThat((List<Object>) rawResults.get(1).get("ninetiethPercentile")).containsExactly(18.0);
assertThat((List<Object>) rawResults.get(0).get("percentileValues")).containsExactly(81.0, 80.0);
assertThat((List<Object>) rawResults.get(1).get("percentileValues")).containsExactly(83.0, 79.0);
}

@Test // GH-4472
@EnableIfMongoServerVersion(isGreaterThanEqual = "7.0")
void medianShouldBeAppliedCorrectly() {

DATAMONGO788 objectToSave = new DATAMONGO788(62, 81, 80);
DATAMONGO788 objectToSave2 = new DATAMONGO788(60, 83, 79);

mongoTemplate.insert(objectToSave);
mongoTemplate.insert(objectToSave2);

Aggregation agg = Aggregation.newAggregation(
project().and(ArithmeticOperators.valueOf("x").median().and("y").and("xField"))
.as("medianValue"));

AggregationResults<Document> result = mongoTemplate.aggregate(agg, DATAMONGO788.class, Document.class);

// MongoDB server returns $median a Double
List<Document> rawResults = (List<Document>) result.getRawResults().get("results");
assertThat(rawResults.get(0).get("medianValue")).isEqualTo(80.0);
assertThat(rawResults.get(1).get("medianValue")).isEqualTo(79.0);
}

@Test // DATAMONGO-1986
Expand Down Expand Up @@ -2152,6 +2177,12 @@ public DATAMONGO788() {}
this.y = y;
this.yField = y;
}

public DATAMONGO788(int x, int y, int xField) {
this.x = x;
this.y = y;
this.xField = xField;
}
}

// DATAMONGO-806
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import org.junit.jupiter.api.Test;

/**
* Unit tests for {@link Round}.
* Unit tests for {@link ArithmeticOperators}.
*
* @author Christoph Strobl
* @author Mark Paluch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,22 @@ void shouldRenderPercentileWithMultipleArgsAggregationExpression() {
assertThat(agg).isEqualTo(Document.parse("{ $project: { scorePercentiles: { $percentile: { input: [\"$scoreOne\", \"$scoreTwo\"], method: \"approximate\", p: [0.4] } }} } }"));
}

@Test // GH-4472
void shouldRenderMedianAggregationExpressions() {

Document singleArgAgg = project()
.and(ArithmeticOperators.valueOf("score").median()).as("medianValue")
.toDocument(Aggregation.DEFAULT_CONTEXT);

assertThat(singleArgAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: \"$score\", method: \"approximate\" } }} } }"));

Document multipleArgsAgg = project()
.and(ArithmeticOperators.valueOf("score").median().and("scoreTwo")).as("medianValue")
.toDocument(Aggregation.DEFAULT_CONTEXT);

assertThat(multipleArgsAgg).isEqualTo(Document.parse("{ $project: { medianValue: { $median: { input: [\"$score\", \"$scoreTwo\"], method: \"approximate\" } }} } }"));
}

private static Document extractOperation(String field, Document fromProjectClause) {
return (Document) fromProjectClause.get(field);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ At the time of this writing, we provide support for the following Aggregation Op
| `setEquals`, `setIntersection`, `setUnion`, `setDifference`, `setIsSubset`, `anyElementTrue`, `allElementsTrue`

| Group/Accumulator Aggregation Operators
| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `percentile`, `stdDevPop`, `stdDevSamp`
| `addToSet`, `bottom`, `bottomN`, `covariancePop`, `covarianceSamp`, `expMovingAvg`, `first`, `firstN`, `last`, `lastN` `max`, `maxN`, `min`, `minN`, `avg`, `push`, `sum`, `top`, `topN`, `count` (+++*+++), `median`, `percentile`, `stdDevPop`, `stdDevSamp`

| Arithmetic Aggregation Operators
| `abs`, `acos`, `acosh`, `add` (+++*+++ via `plus`), `asin`, `asin`, `atan`, `atan2`, `atanh`, `ceil`, `cos`, `cosh`, `derivative`, `divide`, `exp`, `floor`, `integral`, `ln`, `log`, `log10`, `mod`, `multiply`, `pow`, `round`, `sqrt`, `subtract` (+++*+++ via `minus`), `sin`, `sinh`, `tan`, `tanh`, `trunc`
Expand Down