Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cross encoder support #1615

Merged
merged 18 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions common/src/main/java/org/opensearch/ml/common/FunctionName.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.common;

import java.util.HashSet;
import java.util.Set;

public enum FunctionName {
LINEAR_REGRESSION,
KMEANS,
Expand All @@ -17,6 +20,7 @@
RCF_SUMMARIZE,
LOGISTIC_REGRESSION,
TEXT_EMBEDDING,
TEXT_SIMILARITY,
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
SPARSE_ENCODING,
SPARSE_TOKENIZE,
METRICS_CORRELATION,
Expand All @@ -30,14 +34,18 @@
}
}

private static final HashSet<FunctionName> DL_MODELS = new HashSet<>(Set.of(
TEXT_EMBEDDING,
TEXT_SIMILARITY,
SPARSE_ENCODING,
SPARSE_TOKENIZE
));

/**
* Check if model is deep learning model.
* @return true for deep learning model.
*/
public static boolean isDLModel(FunctionName functionName) {
if (functionName == TEXT_EMBEDDING || functionName == SPARSE_ENCODING || functionName == SPARSE_TOKENIZE) {
return true;
}
return false;
return DL_MODELS.contains(functionName);

Check warning on line 49 in common/src/main/java/org/opensearch/ml/common/FunctionName.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/FunctionName.java#L49

Added line #L49 was not covered by tests
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME,
TEXT_DOCS,
TEXT_SIMILARITY,
REMOTE
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2023 Aryn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ylwu-amzn did we wrap up this discussion about adding Aryn in the code base?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as I understand, all this does is protect us on the off-chance that AWS does what Elastic did a few years ago and switches out the license or something. This is coming from Mehul, who oversaw that whole transition (elastic fork -> opensearch).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protect us on the off-chance that AWS does what Elastic did a few years ago and switches out the license or something

No need to worry that actually, if you check Elastic, community can still use the open source version even they changed license, otherwise we can't do "elastic fork -> opensearch". If you worry this part, I think we should suggest not adding Aryn to license header. What if some other company make a change to this file, they may also prefer to add their license. Then in future how can Aryn tell which company holds the license for this file? It looks not reasonable that not allowing other guys modifying your code, right?

I guess the other reason for adding Aryn is to show the credit of this feature. We are planning to build some way to show attribution/appreciation for features from community contributors. For example, we can maintain a contribution list file like this

Feature Description Contributor
Conversation search This feature will... Austin (from Aryn.ai) Henry (from Aryn.ai)

Any suggestion?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we already decided this with @sean-zheng-amazon and @mashah

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HenryL27 , Can you please point to the discussions where this is decided and what exactly is the decision?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was decided prior to the 2.10 release as part of a PR that we contributed. I believe the initial discussion was done in that PR, but was later concluded over a call.

* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.TEXT_SIMILARITY)
public class TextSimilarityInputDataSet extends MLInputDataset {

List<String> textDocs;

String queryText;

@Builder(toBuilder = true)
public TextSimilarityInputDataSet(String queryText, List<String> textDocs) {
super(MLInputDataType.TEXT_SIMILARITY);
Objects.requireNonNull(textDocs);
Objects.requireNonNull(queryText);
if(textDocs.isEmpty()) {
throw new IllegalArgumentException("No text documents provided");
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
}
this.textDocs = textDocs;
this.queryText = queryText;
}

public TextSimilarityInputDataSet(StreamInput in) throws IOException {
super(MLInputDataType.TEXT_SIMILARITY);
this.queryText = in.readString();
int size = in.readInt();
this.textDocs = new ArrayList<String>();
for(int i = 0; i < size; i++) {
String context = in.readString();
this.textDocs.add(context);
}
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(queryText);
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
out.writeInt(this.textDocs.size());
for (String doc : this.textDocs) {
out.writeString(doc);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
Expand All @@ -21,6 +22,7 @@
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand Down Expand Up @@ -55,6 +57,8 @@
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";
// Input query text to compare against for text similarity model
public static final String QUERY_TEXT_FIELD = "query_text";

// Algorithm name
protected FunctionName algorithm;
Expand Down Expand Up @@ -157,6 +161,20 @@
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
break;
case TEXT_SIMILARITY:
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> tdocs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (tdocs != null && !tdocs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : tdocs) {
builder.value(d);
}
builder.endArray();
}
break;
default:
break;
}
Expand Down Expand Up @@ -186,6 +204,7 @@
List<String> targetResponse = new ArrayList<>();
List<Integer> targetResponsePositions = new ArrayList<>();
List<String> textDocs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -233,6 +252,9 @@
textDocs.add(parser.text());
}
break;
case QUERY_TEXT_FIELD:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

queryText = parser.text();
break;

Check warning on line 257 in common/src/main/java/org/opensearch/ml/common/input/MLInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/input/MLInput.java#L256-L257

Added lines #L256 - L257 were not covered by tests
default:
parser.skipChildren();
break;
Expand All @@ -243,6 +265,9 @@
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
if (algorithm == FunctionName.TEXT_SIMILARITY) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have test for this section?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhrubo-os jacoco is showing red for like, this entire section (despite my attempts to the contrary). I think what's happening is that the classLoader stuff at the top of this method preempts all the MLInput parsing logic and relegates it to the subclass. We should probly just remove this stuff, idk?

inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);

Check warning on line 269 in common/src/main/java/org/opensearch/ml/common/input/MLInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/input/MLInput.java#L269

Added line #L269 was not covered by tests
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.input.nlp;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;


/**
* MLInput which supports a text similarity algorithm
* Inputs are a query and a list of texts. Outputs are real numbers
* Use this for Cross Encoder models
*/
@org.opensearch.ml.common.annotation.MLInput(functionNames = {FunctionName.TEXT_SIMILARITY})
public class TextSimilarityMLInput extends MLInput {

public TextSimilarityMLInput(FunctionName algorithm, MLInputDataset dataset) {
super(algorithm, null, dataset);
}

public TextSimilarityMLInput(StreamInput in) throws IOException {
super(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ALGORITHM_FIELD, algorithm.name());
if(parameters != null) {
builder.field(ML_PARAMETERS_FIELD, parameters);

Check warning on line 62 in common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java#L62

Added line #L62 was not covered by tests
}
if(inputDataset != null) {
TextSimilarityInputDataSet ds = (TextSimilarityInputDataSet) this.inputDataset;
List<String> docs = ds.getTextDocs();
String queryText = ds.getQueryText();
builder.field(QUERY_TEXT_FIELD, queryText);
if (docs != null && !docs.isEmpty()) {
builder.startArray(TEXT_DOCS_FIELD);
for(String d : docs) {
builder.value(d);
}
builder.endArray();
}
}
builder.endObject();
return builder;
}

public TextSimilarityMLInput(XContentParser parser, FunctionName functionName) throws IOException {
super();
this.algorithm = functionName;
List<String> docs = new ArrayList<>();
String queryText = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
parser.nextToken();

switch (fieldName) {
case TEXT_DOCS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
String context = parser.text();
docs.add(context);
}
break;
case QUERY_TEXT_FIELD:
queryText = parser.text();
default:
parser.skipChildren();
break;
}
}
if(docs.isEmpty()) {
throw new IllegalArgumentException("No text documents were provided");
}
if(queryText == null) {
throw new IllegalArgumentException("No query text was provided");

Check warning on line 111 in common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java

View check run for this annotation

Codecov / codecov/patch

common/src/main/java/org/opensearch/ml/common/input/nlp/TextSimilarityMLInput.java#L111

Added line #L111 was not covered by tests
}
inputDataset = new TextSimilarityInputDataSet(queryText, docs);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.common.dataset;

import static org.junit.Assert.assertThrows;

import java.io.IOException;
import java.util.List;

import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.BytesStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

public class TextSimilarityInputDatasetTest {

@Test
public void testStreaming() throws IOException {
List<String> docs = List.of("That is a happy dog", "it's summer");
String queryText = "today is sunny";
TextSimilarityInputDataSet dataset = TextSimilarityInputDataSet.builder().queryText(queryText).textDocs(docs).build();
BytesStreamOutput outbytes = new BytesStreamOutput();
StreamOutput osso = new OutputStreamStreamOutput(outbytes);
dataset.writeTo(osso);
StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes()));
TextSimilarityInputDataSet newDs = (TextSimilarityInputDataSet) MLInputDataset.fromStream(in);
assert (dataset.getTextDocs().equals(newDs.getTextDocs()));
assert (dataset.getQueryText().equals(newDs.getQueryText()));
}

@Test
public void noPairs_ThenFail() {
List<String> docs = List.of();
String queryText = "today is sunny";
IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
assert (e.getMessage().equals("No text documents provided"));
}

@Test
public void noQuery_ThenFail() {
List<String> docs = List.of("That is a happy dog", "it's summer");
String queryText = null;
assertThrows(NullPointerException.class,
() -> TextSimilarityInputDataSet.builder().textDocs(docs).queryText(queryText).build());
}
}
Loading
Loading