Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding weights param for combination technique #235

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,39 @@

package org.opensearch.neuralsearch.processor.combination;

import lombok.NoArgsConstructor;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Abstracts combination of scores based on arithmetic mean method
*/
@NoArgsConstructor
public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique {

public static final String TECHNIQUE_NAME = "arithmetic_mean";
public static final String PARAM_NAME_WEIGHTS = "weights";
private static final Set<String> SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS);
private static final Float ZERO_SCORE = 0.0f;
private final List<Float> weights;

public ArithmeticMeanScoreCombinationTechnique(final Map<String, Object> params) {
validateParams(params);
weights = getWeights(params);
}

private List<Float> getWeights(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
return List.of();
}
// get weights, we don't need to check for instance as it's done during validation
return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
.map(Double::floatValue)
.collect(Collectors.toUnmodifiableList());
}

/**
* Arithmetic mean method for combining scores.
Expand All @@ -26,8 +49,11 @@ public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombination
public float combine(final float[] scores) {
float combinedScore = 0.0f;
int count = 0;
for (float score : scores) {
for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
float score = scores[indexOfSubQuery];
if (score >= 0.0) {
float weight = getWeightForSubQuery(indexOfSubQuery);
score = score * weight;
combinedScore += score;
count++;
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
}
Expand All @@ -37,4 +63,42 @@ public float combine(final float[] scores) {
}
return combinedScore / count;
}

private void validateParams(final Map<String, Object> params) {
if (Objects.isNull(params) || params.isEmpty()) {
return;
}
// check if only supported params are passed
Optional<String> optionalNotSupportedParam = params.keySet()
.stream()
.filter(paramName -> !SUPPORTED_PARAMS.contains(paramName))
.findFirst();
if (optionalNotSupportedParam.isPresent()) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"provided parameter for combination technique is not supported. supported parameters are [%s]",
SUPPORTED_PARAMS.stream().collect(Collectors.joining(","))
)
);
}

// check param types
if (params.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) {
if (!(params.get(PARAM_NAME_WEIGHTS) instanceof List)) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS)
);
}
}
}

/**
* Get weight for sub-query based on its index in the hybrid search query. Use user provided weight or 1.0 otherwise
* @param indexOfSubQuery 0-based index of sub-query in the Hybrid Search query
* @return weight for sub-query, use one that is set in processor/pipeline definition or 1.0 as default
*/
private float getWeightForSubQuery(int indexOfSubQuery) {
return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery).floatValue() : 1.0f;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,18 @@

import java.util.Map;
import java.util.Optional;

import org.opensearch.OpenSearchParseException;
import java.util.function.Function;

/**
* Abstracts creation of exact score combination method based on technique name
*/
public class ScoreCombinationFactory {

public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique();
public static final ScoreCombinationTechnique DEFAULT_METHOD = new ArithmeticMeanScoreCombinationTechnique(Map.of());

private final Map<String, ScoreCombinationTechnique> scoreCombinationMethodsMap = Map.of(
private final Map<String, Function<Map<String, Object>, ScoreCombinationTechnique>> scoreCombinationMethodsMap = Map.of(
ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME,
new ArithmeticMeanScoreCombinationTechnique()
ArithmeticMeanScoreCombinationTechnique::new
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
);

/**
Expand All @@ -28,7 +27,18 @@ public class ScoreCombinationFactory {
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique) {
return createCombination(technique, Map.of());
}

/**
* Get score combination method by technique name
* @param technique name of technique
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
* @param params parameters that combination technique may use
* @return instance of ScoreCombinationTechnique for technique name
*/
public ScoreCombinationTechnique createCombination(final String technique, final Map<String, Object> params) {
return Optional.ofNullable(scoreCombinationMethodsMap.get(technique))
.orElseThrow(() -> new OpenSearchParseException("provided combination technique is not supported"));
.orElseThrow(() -> new IllegalArgumentException("provided combination technique is not supported"))
.apply(params);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
package org.opensearch.neuralsearch.processor.factory;

import static org.opensearch.ingest.ConfigurationUtils.readOptionalMap;
import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty;

import java.util.Map;
import java.util.Objects;

import lombok.AllArgsConstructor;

import org.opensearch.OpenSearchParseException;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
Expand All @@ -29,6 +31,7 @@ public class NormalizationProcessorFactory implements Processor.Factory<SearchPh
public static final String NORMALIZATION_CLAUSE = "normalization";
public static final String COMBINATION_CLAUSE = "combination";
public static final String TECHNIQUE = "technique";
public static final String PARAMETERS = "parameters";

private final NormalizationProcessorWorkflow normalizationProcessorWorkflow;
private ScoreNormalizationFactory scoreNormalizationFactory;
Expand All @@ -54,8 +57,14 @@ public SearchPhaseResultsProcessor create(

ScoreCombinationTechnique scoreCombinationTechnique = ScoreCombinationFactory.DEFAULT_METHOD;
if (Objects.nonNull(combinationClause)) {
String combinationTechnique = (String) combinationClause.getOrDefault(TECHNIQUE, "");
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique);
String combinationTechnique = readOptionalStringProperty(NormalizationProcessor.TYPE, tag, combinationClause, TECHNIQUE);
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
// check for optional combination params
Map<String, Object> combinationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, combinationClause, PARAMETERS);
try {
scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique, combinationParams);
} catch (IllegalArgumentException illegalArgumentException) {
throw new OpenSearchParseException(illegalArgumentException.getMessage(), illegalArgumentException);
}
}

return new NormalizationProcessor(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -57,8 +58,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase {

private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000;
private static final String DEFAULT_USER_AGENT = "Kibana";
private static final String NORMALIZATION_METHOD = "min_max";
private static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String NORMALIZATION_METHOD = "min_max";
protected static final String COMBINATION_METHOD = "arithmetic_mean";
protected static final String PARAM_NAME_WEIGHTS = "weights";

protected final ClassLoader classLoader = this.getClass().getClassLoader();

Expand Down Expand Up @@ -566,30 +568,31 @@ protected void createSearchPipeline(
final String pipelineId,
final String normalizationMethod,
String combinationMethod,
final Map<String, Object> combinationParams
final Map<String, String> combinationParams
) {
StringBuilder stringBuilderForContentBody = new StringBuilder();
stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",")
.append("\"phase_results_processors\": [{ ")
.append("\"normalization-processor\": {")
.append("\"normalization\": {")
.append("\"technique\": \"%s\"")
.append("},")
.append("\"combination\": {")
.append("\"technique\": \"%s\"");
if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) {
stringBuilderForContentBody.append(", \"parameters\": {");
if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) {
stringBuilderForContentBody.append("\"weights\": ").append(combinationParams.get(PARAM_NAME_WEIGHTS));
}
stringBuilderForContentBody.append(" }");
}
stringBuilderForContentBody.append("}").append("}}]}");
makeRequest(
client(),
"PUT",
String.format(LOCALE, "/_search/pipeline/%s", pipelineId),
null,
toHttpEntity(
String.format(
LOCALE,
"{\"description\": \"Post processor pipeline\","
+ "\"phase_results_processors\": [{ "
+ "\"normalization-processor\": {"
+ "\"normalization\": {"
+ "\"technique\": \"%s\""
+ "},"
+ "\"combination\": {"
+ "\"technique\": \"%s\""
+ "}"
+ "}}]}",
normalizationMethod,
combinationMethod
)
),
toHttpEntity(String.format(LOCALE, stringBuilderForContentBody.toString(), normalizationMethod, combinationMethod)),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -250,6 +251,92 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf
assertQueryResults(searchResponseAsMap, 4, true);
}

@SneakyThrows
public void testCombinationParams_whenWeightsParamSet_thenSuccessful() {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
// check case when number of weights and sub-queries are same
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
);

HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
hybridQueryBuilder.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));

Map<String, Object> searchResponseWithWeights1AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights1AsMap, 0.6, 0.5, 0.001);

// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f }))
);

Map<String, Object> searchResponseWithWeights2AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights2AsMap, 2.0, 0.8, 0.001);

// check case when number of weights is less than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f }))
);

Map<String, Object> searchResponseWithWeights3AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 0.8, 0.001);

// check case when number of weights is more than number of sub-queries
// delete existing pipeline and create a new one with another set of weights
deleteSearchPipeline(SEARCH_PIPELINE);
createSearchPipeline(
SEARCH_PIPELINE,
NORMALIZATION_METHOD,
COMBINATION_METHOD,
Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f }))
);

Map<String, Object> searchResponseWithWeights4AsMap = search(
TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
hybridQueryBuilder,
null,
5,
Map.of("search_pipeline", SEARCH_PIPELINE)
);

assertWeightedScores(searchResponseWithWeights4AsMap, 0.6, 0.5, 0.001);
}

private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
prepareKnnIndex(
Expand Down Expand Up @@ -402,4 +489,31 @@ private void assertQueryResults(Map<String, Object> searchResponseAsMap, int tot
// verify that all ids are unique
assertEquals(Set.copyOf(ids).size(), ids.size());
}

private void assertWeightedScores(
Map<String, Object> searchResponseWithWeightsAsMap,
double expectedMaxScore,
double expectedMaxMinusOneScore,
double expectedMinScore
) {
assertNotNull(searchResponseWithWeightsAsMap);
Map<String, Object> totalWeights = getTotalHits(searchResponseWithWeightsAsMap);
assertNotNull(totalWeights.get("value"));
assertEquals(4, totalWeights.get("value"));
assertNotNull(totalWeights.get("relation"));
assertEquals(RELATION_EQUAL_TO, totalWeights.get("relation"));
assertTrue(getMaxScore(searchResponseWithWeightsAsMap).isPresent());
assertEquals(expectedMaxScore, getMaxScore(searchResponseWithWeightsAsMap).get(), 0.001f);

List<Double> scoresWeights = new ArrayList<>();
for (Map<String, Object> oneHit : getNestedHits(searchResponseWithWeightsAsMap)) {
scoresWeights.add((Double) oneHit.get("_score"));
}
// verify scores order
assertTrue(IntStream.range(0, scoresWeights.size() - 1).noneMatch(idx -> scoresWeights.get(idx) < scoresWeights.get(idx + 1)));
// verify the scores are normalized with inclusion of weights
assertEquals(expectedMaxScore, scoresWeights.get(0), 0.001);
assertEquals(expectedMaxMinusOneScore, scoresWeights.get(1), 0.001);
assertEquals(expectedMinScore, scoresWeights.get(scoresWeights.size() - 1), 0.001);
}
}
Loading
Loading