Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add multi-column support to AggFormula #6206

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
//
package io.deephaven.engine.table.impl.by;

import io.deephaven.api.ColumnName;
import io.deephaven.api.Pair;
import io.deephaven.api.SortColumn;
import io.deephaven.api.*;
import io.deephaven.api.agg.*;
import io.deephaven.api.agg.spec.AggSpec;
import io.deephaven.api.agg.spec.AggSpecAbsSum;
Expand Down Expand Up @@ -93,13 +91,15 @@
import io.deephaven.engine.table.impl.by.ssmcountdistinct.unique.ShortRollupUniqueOperator;
import io.deephaven.engine.table.impl.by.ssmminmax.SsmChunkedMinMaxOperator;
import io.deephaven.engine.table.impl.by.ssmpercentile.SsmChunkedPercentileOperator;
import io.deephaven.engine.table.impl.select.SelectColumn;
import io.deephaven.engine.table.impl.sources.ReinterpretUtils;
import io.deephaven.engine.table.impl.ssms.SegmentedSortedMultiSet;
import io.deephaven.engine.table.impl.util.freezeby.FreezeByCountOperator;
import io.deephaven.engine.table.impl.util.freezeby.FreezeByOperator;
import io.deephaven.time.DateTimeUtils;
import io.deephaven.util.annotations.FinalDefault;
import io.deephaven.util.type.ArrayTypeUtils;
import io.deephaven.vector.VectorFactory;
import org.apache.commons.lang3.mutable.MutableBoolean;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
Expand All @@ -113,6 +113,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -153,6 +154,13 @@ private enum Type {
private final Collection<? extends Aggregation> aggregations;
private final Type type;

/**
* For {@link Formula formula} aggregations we need a representation of the table definition in which every
* non-key column's data type is converted to the corresponding {@link io.deephaven.vector.Vector vector} type
* (group-by key columns keep their original definitions). This is computed lazily on first use and re-used
* across all formula aggregations.
*/
private Map<String, ColumnDefinition<?>> vectorColumnDefinitions;

/**
* Convert a collection of {@link Aggregation aggregations} to an {@link AggregationContextFactory}.
*
Expand Down Expand Up @@ -707,6 +715,51 @@ public void visit(@NotNull final Partition partition) {
groupByColumnNames));
}

@Override
public void visit(@NotNull final Formula formula) {
rcaudy marked this conversation as resolved.
Show resolved Hide resolved
final SelectColumn selectColumn = SelectColumn.of(formula.selectable());
lbooker42 marked this conversation as resolved.
Show resolved Hide resolved

// Get or create a column definition map composed of vectors of the original column types (or scalars when
// part of the group_by columns).
final Set<String> groupByColumnSet = Set.of(groupByColumnNames);
if (vectorColumnDefinitions == null) {
vectorColumnDefinitions = table.getDefinition().getColumnStream().collect(Collectors.toMap(
ColumnDefinition::getName,
(final ColumnDefinition<?> cd) -> groupByColumnSet.contains(cd.getName())
? cd
: ColumnDefinition.fromGenericType(
cd.getName(),
VectorFactory.forElementType(cd.getDataType()).vectorType(),
cd.getDataType())));
}

// Get the input column names from the formula and provide them to the groupBy operator
final String[] allInputColumns =
selectColumn.initDef(vectorColumnDefinitions, compilationProcessor).toArray(String[]::new);
final String[] inputKeyColumns = Arrays.stream(allInputColumns)
.filter(groupByColumnSet::contains)
.toArray(String[]::new);
final String[] inputNonKeyColumns = Arrays.stream(allInputColumns)
.filter(col -> !groupByColumnSet.contains(col))
.toArray(String[]::new);

if (!selectColumn.getColumnArrays().isEmpty()) {
throw new IllegalArgumentException("AggFormula does not support column arrays ("
+ selectColumn.getColumnArrays() + ")");
}
if (selectColumn.hasVirtualRowVariables()) {
lbooker42 marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException("AggFormula does not support virtual row variables");
}
// TODO: re-use shared groupBy operators (https://github.com/deephaven/deephaven-core/issues/6363)
final GroupByChunkedOperator groupByChunkedOperator = new GroupByChunkedOperator(table, false, null,
Arrays.stream(inputNonKeyColumns).map(col -> MatchPair.of(Pair.parse(col)))
.toArray(MatchPair[]::new));

final FormulaMultiColumnChunkedOperator op = new FormulaMultiColumnChunkedOperator(table,
groupByChunkedOperator, true, selectColumn, inputKeyColumns);
addNoInputOperator(op);
}

// -------------------------------------------------------------------------------------------------------------
// AggSpec.Visitor
// -------------------------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -745,6 +798,7 @@ public void visit(@NotNull final AggSpecFirst first) {
@Override
public void visit(@NotNull final AggSpecFormula formula) {
unsupportedForBlinkTables("Formula");
// TODO: re-use shared groupBy operators (https://github.com/deephaven/deephaven-core/issues/6363)
final GroupByChunkedOperator groupByChunkedOperator = new GroupByChunkedOperator(table, false, null,
resultPairs.stream().map(pair -> MatchPair.of((Pair) pair.input())).toArray(MatchPair[]::new));
final FormulaChunkedOperator formulaChunkedOperator = new FormulaChunkedOperator(groupByChunkedOperator,
Expand Down Expand Up @@ -860,6 +914,12 @@ default void visit(@NotNull final LastRowKey lastRowKey) {
rollupUnsupported("LastRowKey");
}

// Multi-column Formula aggregations are routed to rollupUnsupported("Formula") by this default; @FinalDefault
// marks the default as not intended to be overridden by implementors.
@Override
@FinalDefault
default void visit(@NotNull final Formula formula) {
rollupUnsupported("Formula");
}

// -------------------------------------------------------------------------------------------------------------
// AggSpec.Visitor for unsupported column aggregation specs
// -------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,8 @@ public UnaryOperator<ModifiedColumnSet> initializeRefreshing(@NotNull final Quer
resultColumnModifiedColumnSets[ci] = resultTable.newModifiedColumnSet(resultColumnNames[ci]);
}
if (delegateToBy) {
// We cannot use the groupBy's result MCS factory, because the result column names are not guaranteed to be
// the
// same.
// We cannot use the groupBy's result MCS factory, because the result column names are not guaranteed
// to be the same.
groupBy.initializeRefreshing(resultTable, aggregationUpdateListener);
}
// Note that we also use the factory in propagateUpdates to identify the set of modified columns to handle.
Expand Down Expand Up @@ -379,7 +378,7 @@ private class DataFillerContext implements SafeCloseable {
private final boolean[] columnsToFillMask;
final FillFromContext[] fillFromContexts;

private DataFillerContext(@NotNull final boolean[] columnsToFillMask) {
private DataFillerContext(final boolean @NotNull [] columnsToFillMask) {
this.columnsToFillMask = columnsToFillMask;
fillFromContexts = new FillFromContext[resultColumnNames.length];
for (int ci = 0; ci < resultColumnNames.length; ++ci) {
Expand Down Expand Up @@ -448,6 +447,7 @@ private void copyData(@NotNull final RowSequence rowSequence, @NotNull final boo
rowSequenceSlice);
}
}
sharedContext.reset();
}
}
}
Expand Down
Loading
Loading