Skip to content

Commit

Permalink
Fixes #2744, support onnx model for TextEmbeddingTranslator (#2749)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanktliu authored Aug 14, 2023
1 parent 17bfda1 commit 5f39a4c
Showing 1 changed file with 24 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ public class TextEmbeddingTranslator implements Translator<String, float[]> {
private Batchifier batchifier;
private boolean normalize;
private String pooling;
private boolean includeTokenTypes;

/**
 * Constructs a {@code TextEmbeddingTranslator}, storing each argument in the field of the same
 * name.
 *
 * @param tokenizer the tokenizer that {@code processInput} uses to encode the input text
 * @param batchifier the batchifier stored for this translator
 * @param pooling the pooling mode name stored for this translator
 * @param normalize the normalize flag stored for this translator
 * @param includeTokenTypes the flag that {@code processInput} passes to {@code Encoding.toNDList}
 */
TextEmbeddingTranslator(
HuggingFaceTokenizer tokenizer,
Batchifier batchifier,
String pooling,
// NOTE(review): rendered diff -- the next line looks like the pre-change end of the parameter
// list, and the two lines after it its replacement; confirm against the real file.
boolean normalize) {
boolean normalize,
boolean includeTokenTypes) {
this.tokenizer = tokenizer;
this.batchifier = batchifier;
this.pooling = pooling;
this.normalize = normalize;
this.includeTokenTypes = includeTokenTypes;
}

/** {@inheritDoc} */
Expand All @@ -58,7 +61,7 @@ public Batchifier getBatchifier() {
public NDList processInput(TranslatorContext ctx, String input) {
// Tokenize the raw text and attach the encoding to the context under the "encoding" key.
Encoding encoding = tokenizer.encode(input);
ctx.setAttachment("encoding", encoding);
// NOTE(review): rendered diff -- the first return below (hard-coded false) looks like the
// pre-change statement; the second, which passes the includeTokenTypes field, replaces it.
return encoding.toNDList(ctx.getNDManager(), false);
return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
}

/** {@inheritDoc} */
Expand All @@ -84,6 +87,10 @@ public TextEmbeddingBatchTranslator toBatchTranslator(Batchifier batchifier) {
static NDArray processEmbedding(
NDManager manager, NDList list, Encoding encoding, String pooling) {
NDArray embedding = list.get("last_hidden_state");
if (embedding == null) {
// For Onnx model, NDArray name is not present
embedding = list.head();
}
long[] attentionMask = encoding.getAttentionMask();
NDArray inputAttentionMask = manager.create(attentionMask).toType(DataType.FLOAT32, true);
switch (pooling) {
Expand Down Expand Up @@ -167,6 +174,7 @@ public static final class Builder {
private Batchifier batchifier = Batchifier.STACK;
private boolean normalize = true;
private String pooling = "mean";
private boolean includeTokenTypes;

Builder(HuggingFaceTokenizer tokenizer) {
this.tokenizer = tokenizer;
Expand Down Expand Up @@ -214,6 +222,17 @@ public Builder optPoolingMode(String poolingMode) {
return this;
}

/**
 * Sets whether the built {@link Translator} includes token types in its input.
 *
 * <p>The value is only recorded on this builder. If this method is never called, the field
 * keeps its Java default of {@code false}.
 *
 * @param includeTokenTypes {@code true} to include token types, {@code false} to omit them
 * @return this builder
 */
public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
    this.includeTokenTypes = includeTokenTypes;
    return this;
}

/**
* Configures the builder with the model arguments.
*
Expand All @@ -224,6 +243,7 @@ public void configure(Map<String, ?> arguments) {
optBatchifier(Batchifier.fromString(batchifierStr));
optNormalize(ArgumentsUtil.booleanValue(arguments, "normalize", true));
optPoolingMode(ArgumentsUtil.stringValue(arguments, "pooling", "mean"));
optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
}

/**
Expand All @@ -233,7 +253,8 @@ public void configure(Map<String, ?> arguments) {
* @throws IOException if I/O error occurs
*/
public TextEmbeddingTranslator build() throws IOException {
// NOTE(review): rendered diff -- the four-argument call looks like the pre-change statement;
// the five-argument call after it, which also forwards includeTokenTypes, replaces it.
return new TextEmbeddingTranslator(tokenizer, batchifier, pooling, normalize);
return new TextEmbeddingTranslator(
tokenizer, batchifier, pooling, normalize, includeTokenTypes);
}
}
}

0 comments on commit 5f39a4c

Please sign in to comment.