diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 3095cbbd2..842df736d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -32,17 +32,6 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } - public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List scores = List.of(1.0f, 0.5f, 0.3f); - List weights = List.of(0.9, 0.2, 0.7); - ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, weights), - scoreCombinationUtil - ); - float expectedScore = 0.6722f; - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); - } - public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { List weights = IntStream.range(0, RANDOM_SCORES_SIZE) .mapToObj(i -> RandomizedTest.randomDouble()) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index 4f6c6f0cf..fe0d962ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -76,17 +76,23 @@ public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); } + /** + * Verify score correctness by using alternative formula for geometric mean as n-th root of product of weighted scores, + * more details in here https://en.wikipedia.org/wiki/Weighted_geometric_mean + */ private float geometricMean(List scores, List weights) { - assertEquals(scores.size(), weights.size()); - float sumOfWeights = 0; - float weightedSumOfLn = 0; - for (int i = 0; i < scores.size(); i++) { - float score = scores.get(i), weight = weights.get(i).floatValue(); - if (score > 0) { - sumOfWeights += weight; - weightedSumOfLn += weight * Math.log(score); + float product = 1.0f; + float sumOfWeights = 0.0f; + for (int indexOfSubQuery = 0; indexOfSubQuery < scores.size(); indexOfSubQuery++) { + float score = scores.get(indexOfSubQuery); + if (score <= 0) { + // scores 0.0 need to be skipped, ln() of 0 is not defined + continue; } + float weight = weights.get(indexOfSubQuery).floatValue(); + product *= Math.pow(score, weight); + sumOfWeights += weight; } - return sumOfWeights == 0 ? 0f : (float) Math.exp(weightedSumOfLn / sumOfWeights); + return sumOfWeights == 0 ? 0f : (float) Math.pow(product, (float) 1 / sumOfWeights); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 651326af7..8187123a1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -43,17 +43,6 @@ public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); } - public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); - ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( - Map.of(PARAM_NAME_WEIGHTS, weights), - scoreCombinationUtil - ); - testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, weights); - } - public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { List scores = List.of(1.0f, -1.0f, 0.6f); List weights = List.of(0.9, 0.2, 0.7);