diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 0413d7e37..8ac7c63be 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -99,6 +99,6 @@ private void validateParams(final Map params) { * @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default */ private float getWeightForSubQuery(int indexOfSubQuery) { - return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery).floatValue() : 1.0f; + return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java index 019cf606b..bf9bb5c12 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -6,18 +6,19 @@ package org.opensearch.neuralsearch.processor.factory; import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap; -import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; import java.util.Map; import java.util.Objects; import lombok.AllArgsConstructor; -import org.opensearch.OpenSearchParseException; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.pipeline.Processor; @@ -49,7 +50,13 @@ public SearchPhaseResultsProcessor create( Map normalizationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, NORMALIZATION_CLAUSE); ScoreNormalizationTechnique normalizationTechnique = ScoreNormalizationFactory.DEFAULT_METHOD; if (Objects.nonNull(normalizationClause)) { - String normalizationTechniqueName = (String) normalizationClause.getOrDefault(TECHNIQUE, ""); + String normalizationTechniqueName = readStringProperty( + NormalizationProcessor.TYPE, + tag, + normalizationClause, + TECHNIQUE, + MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME + ); normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName); } @@ -57,14 +64,16 @@ public SearchPhaseResultsProcessor create( ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD; if (Objects.nonNull(combinationClause)) { - String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE); + String combinationTechnique = readStringProperty( + NormalizationProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME + ); // check for optional combination params Map combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS); - try { - scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams); - } catch (IllegalArgumentException illegalArgumentException) { - throw new OpenSearchParseException(illegalArgumentException.getMessage(), illegalArgumentException); - } + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams); } return new NormalizationProcessor( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 7d5317c14..59b4a0b8b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -20,7 +20,7 @@ */ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique { - protected static final String TECHNIQUE_NAME = "min_max"; + public static final String TECHNIQUE_NAME = "min_max"; private static final float MIN_SCORE = 0.001f; private static final float SINGLE_RESULT_SCORE = 1.0f; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 17bf8cb23..b469e241b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -8,8 +8,6 @@ import java.util.Map; import java.util.Optional; -import org.opensearch.OpenSearchParseException; - /** * Abstracts creation of exact score normalization method based on technique name */ @@ -29,6 +27,6 @@ public class ScoreNormalizationFactory { */ public ScoreNormalizationTechnique createNormalization(final String technique) { return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) - .orElseThrow(() -> new OpenSearchParseException("provided normalization technique is not supported")); + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 7818621ee..83bb0e7bb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -18,7 +18,6 @@ import lombok.SneakyThrows; -import org.opensearch.OpenSearchParseException; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; @@ -64,6 +63,35 @@ public void testNormalizationProcessor_whenNoParams_thenSuccessful() { assertEquals("normalization-processor", normalizationProcessor.getType()); } + @SneakyThrows + public void testNormalizationProcessor_whenTechniqueNamesNotSet_thenSuccessful() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put("normalization", new HashMap<>(Map.of())); + config.put("combination", new HashMap<>(Map.of())); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = normalizationProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof NormalizationProcessor); + NormalizationProcessor normalizationProcessor = (NormalizationProcessor) searchPhaseResultsProcessor; + assertEquals("normalization-processor", normalizationProcessor.getType()); + } + @SneakyThrows public void testNormalizationProcessor_whenWithParams_thenSuccessful() { NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( @@ -113,7 +141,7 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( TECHNIQUE, "arithmetic_mean", PARAMETERS, - new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomFloat(), RandomizedTest.randomFloat()))) + new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble()))) ) ) ); @@ -145,7 +173,7 @@ public void testInputValidation_whenInvalidNormalizationClause_thenFail() { Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); expectThrows( - OpenSearchParseException.class, + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -154,9 +182,9 @@ public void testInputValidation_whenInvalidNormalizationClause_thenFail() { new HashMap<>( Map.of( NormalizationProcessorFactory.NORMALIZATION_CLAUSE, - Map.of(TECHNIQUE, ""), + new HashMap(Map.of(TECHNIQUE, "")), NormalizationProcessorFactory.COMBINATION_CLAUSE, - Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME) + new HashMap(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME)) ) ), pipelineContext @@ -164,7 +192,7 @@ public void testInputValidation_whenInvalidNormalizationClause_thenFail() { ); expectThrows( - OpenSearchParseException.class, + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -196,7 +224,7 @@ public void testInputValidation_whenInvalidCombinationClause_thenFail() { Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); expectThrows( - OpenSearchParseException.class, + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -215,7 +243,7 @@ public void testInputValidation_whenInvalidCombinationClause_thenFail() { ); expectThrows( - OpenSearchParseException.class, + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -246,8 +274,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() { boolean ignoreFailure = false; Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); - OpenSearchParseException exceptionBadTechnique = expectThrows( - OpenSearchParseException.class, + IllegalArgumentException exceptionBadTechnique = expectThrows( + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -273,8 +301,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() { ); assertThat(exceptionBadTechnique.getMessage(), containsString("provided combination technique is not supported")); - OpenSearchParseException exceptionInvalidWeights = expectThrows( - OpenSearchParseException.class, + IllegalArgumentException exceptionInvalidWeights = expectThrows( + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -300,8 +328,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() { ); assertThat(exceptionInvalidWeights.getMessage(), containsString("parameter [weights] must be a collection of numbers")); - OpenSearchParseException exceptionInvalidWeights2 = expectThrows( - OpenSearchParseException.class, + IllegalArgumentException exceptionInvalidWeights2 = expectThrows( + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag, @@ -327,8 +355,8 @@ public void testInputValidation_whenInvalidCombinationParams_thenFail() { ); assertThat(exceptionInvalidWeights2.getMessage(), containsString("parameter [weights] must be a collection of numbers")); - OpenSearchParseException exceptionInvalidParam = expectThrows( - OpenSearchParseException.class, + IllegalArgumentException exceptionInvalidParam = expectThrows( + IllegalArgumentException.class, () -> normalizationProcessorFactory.create( processorFactories, tag,