Skip to content

Commit

Permalink
minor changes based on comments
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Sep 27, 2023
1 parent 508b462 commit aae62d4
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 17 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results.
*/
@Log4j2
public class SparseEncodingProcessor extends NLPProcessor {
public final class SparseEncodingProcessor extends NLPProcessor {

public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
public class TextEmbeddingProcessor extends NLPProcessor {
public final class TextEmbeddingProcessor extends NLPProcessor {

public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;

/**
* Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
@Log4j2
public class SparseEncodingProcessorFactory implements Processor.Factory {
private final MLCommonsClientAccessor clientAccessor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;

/**
* Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
*/
public class TextEmbeddingProcessorFactory implements Processor.Factory {

private final MLCommonsClientAccessor clientAccessor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,24 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr
if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new ParsingException(
parser.getTokenLocation(),
"["
+ NAME
+ "] query doesn't support multiple fields, found ["
+ sparseEncodingQueryBuilder.fieldName()
+ "] and ["
+ parser.currentName()
+ "]"
String.format(
"[%s] query doesn't support multiple fields, found [%s] and [%s]",
NAME,
sparseEncodingQueryBuilder.fieldName(),
parser.currentName()
)
);
}

requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query");
requireValue(
sparseEncodingQueryBuilder.queryText(),
QUERY_TEXT_FIELD.getPreferredName() + " must be provided for " + NAME + " query"
String.format("%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME)
);
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format("%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
);
requireValue(sparseEncodingQueryBuilder.modelId(), MODEL_ID_FIELD.getPreferredName() + " must be provided for " + NAME + " query");

return sparseEncodingQueryBuilder;
}
Expand All @@ -164,13 +166,13 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] query does not support [" + currentFieldName + "]"
String.format("[%s] query does not support [%s] field", NAME, currentFieldName)
);
}
} else {
throw new ParsingException(
parser.getTokenLocation(),
"[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]"
String.format("[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName)
);
}
}
Expand Down Expand Up @@ -220,7 +222,7 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
private static void validateForRewrite(String queryText, String modelId) {
if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) {
throw new IllegalArgumentException(
QUERY_TEXT_FIELD.getPreferredName() + " and " + MODEL_ID_FIELD.getPreferredName() + " cannot be null."
String.format("%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName())
);
}
}
Expand All @@ -238,7 +240,7 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
if (entry.getValue() <= 0) {
throw new IllegalArgumentException(
"Feature weight must be larger than 0, got: " + entry.getValue() + "for key " + entry.getKey()
"Feature weight must be larger than 0, feature [" + entry.getValue() + "] has negative weight."
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ public class TokenWeightUtil {
* @param mapResultList {@link Map} which is the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor}
*/
public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String, ?>> mapResultList) {
if (null == mapResultList || mapResultList.isEmpty()) {
throw new IllegalArgumentException("The inference result can not be null or empty.");
}
List<Object> results = new ArrayList<>();
for (Map<String, ?> map : mapResultList) {
if (!map.containsKey(RESPONSE_KEY)) {
Expand All @@ -66,7 +69,7 @@ private static Map<String, Float> buildTokenWeightMap(Object uncastedMap) {
Map<String, Float> result = new HashMap<>();
for (Map.Entry<?, ?> entry : ((Map<?, ?>) uncastedMap).entrySet()) {
if (!String.class.isAssignableFrom(entry.getKey().getClass()) || !Number.class.isAssignableFrom(entry.getValue().getClass())) {
throw new IllegalArgumentException("The expected inference result is a Map with String keys and " + " Float values.");
throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values.");
}
result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue());
}
Expand Down

0 comments on commit aae62d4

Please sign in to comment.