Skip to content

Commit

Permalink
fix: correct weighted summation null handling behavior (#5660)
Browse files Browse the repository at this point in the history
* Corrects null handling issue and potential data truncation.

* Another logic correction and renaming misleading variables.
  • Loading branch information
lbooker42 committed Jun 24, 2024
1 parent 31a9895 commit 91d5df5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.deephaven.chunk.attributes.Values;
import io.deephaven.chunk.*;
import io.deephaven.util.mutable.MutableInt;
import io.deephaven.util.mutable.MutableLong;

import java.util.Collections;
import java.util.Map;
Expand Down Expand Up @@ -46,12 +47,12 @@ public void addChunk(BucketedContext bucketedContext, Chunk<? extends Values> va
IntChunk<ChunkPositions> startPositions, IntChunk<ChunkLengths> length,
WritableBooleanChunk<Values> stateModified) {
final Context context = (Context) bucketedContext;
final LongChunk<? extends Values> doubleValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getAddedWeights();
Assert.neqNull(weightValues, "weightValues");
for (int ii = 0; ii < startPositions.size(); ++ii) {
final int startPosition = startPositions.get(ii);
stateModified.set(ii, addChunk(doubleValues, weightValues, startPosition, length.get(ii),
stateModified.set(ii, addChunk(longValues, weightValues, startPosition, length.get(ii),
destinations.get(startPosition)));
}
}
Expand All @@ -62,12 +63,12 @@ public void removeChunk(BucketedContext bucketedContext, Chunk<? extends Values>
IntChunk<ChunkPositions> startPositions, IntChunk<ChunkLengths> length,
WritableBooleanChunk<Values> stateModified) {
final Context context = (Context) bucketedContext;
final LongChunk<? extends Values> doubleValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getRemovedWeights();
Assert.neqNull(weightValues, "weightValues");
for (int ii = 0; ii < startPositions.size(); ++ii) {
final int startPosition = startPositions.get(ii);
stateModified.set(ii, removeChunk(doubleValues, weightValues, startPosition, length.get(ii),
stateModified.set(ii, removeChunk(longValues, weightValues, startPosition, length.get(ii),
destinations.get(startPosition)));
}
}
Expand All @@ -93,18 +94,18 @@ public void modifyChunk(BucketedContext bucketedContext, Chunk<? extends Values>
public boolean addChunk(SingletonContext singletonContext, int chunkSize, Chunk<? extends Values> values,
LongChunk<? extends RowKeys> inputRowKeys, long destination) {
final Context context = (Context) singletonContext;
final LongChunk<? extends Values> doubleValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.toLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getAddedWeights();
return addChunk(doubleValues, weightValues, 0, values.size(), destination);
return addChunk(longValues, weightValues, 0, values.size(), destination);
}

@Override
public boolean removeChunk(SingletonContext singletonContext, int chunkSize, Chunk<? extends Values> values,
LongChunk<? extends RowKeys> inputRowKeys, long destination) {
final Context context = (Context) singletonContext;
final LongChunk<? extends Values> doubleValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> longValues = context.prevToLongCast.apply(values);
final LongChunk<? extends Values> weightValues = weightOperator.getRemovedWeights();
return removeChunk(doubleValues, weightValues, 0, values.size(), destination);
return removeChunk(longValues, weightValues, 0, values.size(), destination);
}

@Override
Expand All @@ -121,19 +122,19 @@ public boolean modifyChunk(SingletonContext singletonContext, int chunkSize, Chu
newDoubleValues.size(), destination);
}

private static void sumChunks(LongChunk<? extends Values> doubleValues, LongChunk<? extends Values> weightValues,
private static void sumChunks(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start,
int length,
MutableInt normalOut,
MutableInt weightedSumOut) {
MutableLong weightedSumOut) {
int normal = 0;
int weightedSum = 0;
long weightedSum = 0;

for (int ii = 0; ii < length; ++ii) {
final double weight = weightValues.get(start + ii);
final double component = doubleValues.get(start + ii);
final long weight = weightValues.get(start + ii);
final long component = longValues.get(start + ii);

if (weight == QueryConstants.NULL_DOUBLE || component == QueryConstants.NULL_DOUBLE) {
if (weight == QueryConstants.NULL_LONG || component == QueryConstants.NULL_LONG) {
continue;
}

Expand All @@ -148,12 +149,12 @@ private static void sumChunks(LongChunk<? extends Values> doubleValues, LongChun
private boolean addChunk(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(longValues, weightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand All @@ -171,21 +172,21 @@ private boolean addChunk(LongChunk<? extends Values> longValues, LongChunk<? ext
weightedSum.set(destination, totalWeightedSum);
}

final double existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
final long existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
return totalWeightedSum != existingResult;
}
return false;
}

private boolean removeChunk(LongChunk<? extends Values> doubleValues, LongChunk<? extends Values> weightValues,
private boolean removeChunk(LongChunk<? extends Values> longValues, LongChunk<? extends Values> weightValues,
int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(doubleValues, weightValues, start, length, normalOut, weightedSumOut);
sumChunks(longValues, weightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand Down Expand Up @@ -226,17 +227,17 @@ private boolean modifyChunk(LongChunk<? extends Values> prevDoubleValues,
LongChunk<? extends Values> prevWeightValues, LongChunk<? extends Values> newDoubleValues,
LongChunk<? extends Values> newWeightValues, int start, int length, long destination) {
final MutableInt normalOut = new MutableInt();
final MutableInt weightedSumOut = new MutableInt();
final MutableLong weightedSumOut = new MutableLong();

sumChunks(prevDoubleValues, prevWeightValues, start, length, normalOut, weightedSumOut);

final int prevNormal = normalOut.get();
final int prevWeightedSum = weightedSumOut.get();
final long prevWeightedSum = weightedSumOut.get();

sumChunks(newDoubleValues, newWeightValues, start, length, normalOut, weightedSumOut);

final int newNormal = normalOut.get();
final int newWeightedSum = weightedSumOut.get();
final long newWeightedSum = weightedSumOut.get();

final long totalNormal;
final long existingNormal = normalCount.getUnsafe(destination);
Expand All @@ -255,12 +256,12 @@ private boolean modifyChunk(LongChunk<? extends Values> prevDoubleValues,
weightedSum.set(destination, totalWeightedSum);
}

final double existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
final long existingResult = resultColumn.getAndSetUnsafe(destination, totalWeightedSum);
return totalWeightedSum != existingResult;
} else {
if (prevNormal > 0) {
weightedSum.set(destination, 0L);
resultColumn.set(destination, QueryConstants.NULL_DOUBLE);
resultColumn.set(destination, QueryConstants.NULL_LONG);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2248,6 +2248,40 @@ private void testWeightedAvgByIncremental(int size, int seed) {

}

@Test
public void testWeightedSumByLong() {
final QueryTable table = testRefreshingTable(i(2, 4, 6).toTracking(),
col("Long1", 2L, 4L, 6L), col("Long2", 1L, 2L, 3L));
final Table result = table.wsumBy("Long2");
TableTools.show(result);
TestCase.assertEquals(1, result.size());
long result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
long wsum = 2 + 8 + 18;
TestCase.assertEquals(wsum, result_wsum);

final ControlledUpdateGraph updateGraph = ExecutionContext.getContext().getUpdateGraph().cast();
updateGraph.runWithinUnitTestCycle(() -> {
addToTable(table, i(8), col("Long1", (long) Integer.MAX_VALUE), col("Long2", 7L));
table.notifyListeners(i(8), i(), i());
});
show(result);
result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
wsum = wsum + (7L * (long) Integer.MAX_VALUE);
TestCase.assertEquals(wsum, result_wsum);
}

@Test
public void testId5522() {
final QueryTable table = testRefreshingTable(i(2, 4, 6).toTracking(),
col("Long1", 10L, 20L, 30L), col("Long2", 1L, NULL_LONG, 1L));
final Table result = table.wsumBy("Long2");
TableTools.show(result);
TestCase.assertEquals(1, result.size());
long result_wsum = result.getColumnSource("Long1", long.class).getLong(result.getRowSet().firstRowKey());
long wsum = 10 + 30;
TestCase.assertEquals(wsum, result_wsum);
}

@Test
public void testWeightedSumByIncremental() {
final int[] sizes = {10, 50, 200};
Expand Down

0 comments on commit 91d5df5

Please sign in to comment.