add doc
JingyaHuang committed Feb 21, 2024
1 parent cb2a498 commit e93c058
Showing 2 changed files with 89 additions and 6 deletions.
87 changes: 82 additions & 5 deletions docs/source/tutorials/sentence_transformers.mdx
@@ -15,14 +15,14 @@ limitations under the License.
-->
# Sentence Transformers on AWS Inferentia with Optimum Neuron

## Text Models

_There is a notebook version of this tutorial [here](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/sentence-transformers/getting-started.ipynb)._

This guide explains how to compile, load, and use [Sentence Transformers (SBERT)](https://www.sbert.net/) models on AWS Inferentia2 with Optimum Neuron, enabling efficient calculation of embeddings. Sentence Transformers are powerful models for generating sentence embeddings. You can use these models to compute sentence/text embeddings for more than 100 languages. The embeddings can then be compared, e.g. with cosine similarity, to find sentences with a similar meaning, which is useful for semantic textual similarity, semantic search, or paraphrase mining.
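As a quick illustration of that workflow, here is a minimal sketch using plain `sentence-transformers` (no Inferentia-specific compilation yet) that embeds a few sentences and compares them with cosine similarity; the model choice `BAAI/bge-small-en-v1.5` and the example sentences are only illustrative:

```python
from sentence_transformers import SentenceTransformer, util

# Plain CPU/GPU model, only meant to illustrate the embedding workflow
model = SentenceTransformer("BAAI/bge-small-en-v1.5")

sentences = [
    "The weather is lovely today.",
    "It is sunny and warm outside.",
    "I forgot my laptop at the office.",
]
embeddings = model.encode(sentences, convert_to_tensor=True)

# Cosine similarity between the first sentence and the two others
scores = util.cos_sim(embeddings[0], embeddings[1:])
print(scores)  # the first pair should score noticeably higher than the second
```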

_Note: Currently only text models are supported; we are working on vision support for CLIP._


## Convert Sentence Transformers model to AWS Inferentia2
### Convert Sentence Transformers model to AWS Inferentia2

First, you need to convert your Sentence Transformers model to a format compatible with AWS Inferentia2. You can compile Sentence Transformers models with Optimum Neuron using either the `optimum-cli` or the `NeuronModelForSentenceTransformers` class. Below you will find an example for both approaches. Make sure `sentence-transformers` is installed; it is only needed for exporting the model.
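For reference, the `NeuronModelForSentenceTransformers` approach for a text model could look roughly like the following sketch; the model ID, shapes, and output directory mirror the CLI command shown below, and this is illustrative rather than the exact snippet hidden in the collapsed part of the diff:

```python
from optimum.neuron import NeuronModelForSentenceTransformers

model_id = "BAAI/bge-small-en-v1.5"

# Static input shapes used for the Neuron compilation
input_shapes = {"batch_size": 1, "sequence_length": 384}

model = NeuronModelForSentenceTransformers.from_pretrained(
    model_id, export=True, library_name="sentence_transformers", **input_shapes
)

# Save the compiled artifacts locally (or push them to the Hugging Face Hub)
model.save_pretrained("bge_emb_inf2/")
```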

@@ -52,7 +52,7 @@ Here we will use the `optimum-cli` to convert the model. Similar to the `NeuronM
optimum-cli export neuron -m BAAI/bge-small-en-v1.5 --library-name sentence_transformers --sequence_length 384 --batch_size 1 --task feature-extraction bge_emb_inf2/
```

## Load compiled Sentence Transformers model and run inference
### Load compiled Sentence Transformers model and run inference

Once we have a compiled Sentence Transformers model, either exported ourselves or available on the Hugging Face Hub, we can load it and run inference. To load the model we use the `NeuronModelForSentenceTransformers` class, which is an abstraction layer for the `SentenceTransformer` class. It automatically pads the input to the specified `sequence_length` and runs inference on AWS Inferentia2.

@@ -79,6 +79,83 @@ print(f"token embeddings: {token_embeddings.shape}") # torch.Size([1, 7, 384])
print(f"sentence_embedding: {sentence_embedding.shape}") # torch.Size([1, 384])
```
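The part of the example collapsed in the diff above loads the compiled model and computes the embeddings roughly as follows; this is a sketch assuming the export directory `bge_emb_inf2/` from the CLI command earlier, with an arbitrary input sentence, and the output names follow the print statements shown above:

```python
from transformers import AutoTokenizer

from optimum.neuron import NeuronModelForSentenceTransformers

# Assumed export directory from the CLI command earlier in this guide
save_directory = "bge_emb_inf2/"
model = NeuronModelForSentenceTransformers.from_pretrained(save_directory)
tokenizer = AutoTokenizer.from_pretrained(save_directory)

# Inputs are padded to the compiled `sequence_length` automatically
inputs = tokenizer("How is the weather today?", return_tensors="pt")
outputs = model(**inputs)

token_embeddings = outputs.token_embeddings      # (batch_size, sequence_len, hidden_size)
sentence_embedding = outputs.sentence_embedding  # (batch_size, hidden_size)
```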
## Production Usage
### Production Usage
For deploying these models in a production environment, refer to the [Amazon SageMaker Blog](https://www.philschmid.de/inferentia2-embeddings).
## CLIP
### Compile CLIP for AWS Inferentia2
You can compile CLIP models with Optimum Neuron using either the `optimum-cli` or the `NeuronModelForSentenceTransformers` class. Adopt whichever approach you prefer:
* With CLI
```bash
optimum-cli export neuron -m sentence-transformers/clip-ViT-B-32 --sequence_length 64 --batch_size 3 --num_channels 3 --height 224 --width 224 --task feature-extraction --library-name sentence_transformers --subfolder 0_CLIPModel clip_emb/
```
* With `NeuronModelForSentenceTransformers` class
```python
from optimum.neuron import NeuronModelForSentenceTransformers
# [Compile]
model_id = "sentence-transformers/clip-ViT-B-32"

# configs for compiling model
input_shapes = {
"num_channels": 3,
"height": 224,
"width": 224,
"batch_size": 1,
"sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
model_id, subfolder="0_CLIPModel", export=True, library_name="sentence_transformers", **input_shapes
)

# Save locally or upload to the HuggingFace Hub
save_directory = "clip_emb"
emb_model.save_pretrained(save_directory)
```

### Load compiled Sentence Transformers model and run inference

```python
from PIL import Image
from sentence_transformers import util
from transformers import CLIPProcessor

from optimum.neuron import NeuronModelForSentenceTransformers

save_directory = "clip_emb"
emb_model = NeuronModelForSentenceTransformers.from_pretrained(save_directory)

processor = CLIPProcessor.from_pretrained(save_directory)
inputs = processor(
text=["Two dogs in the snow", 'A cat on a table', 'A picture of London at night'], images=Image.open("two_dogs_in_snow.jpg"), return_tensors="pt", padding=True
)

outputs = emb_model(**inputs)


# Compute cosine similarities
cos_scores = util.cos_sim(outputs.image_embeds, outputs.text_embeds)
print(cos_scores)

# tensor([[0.3072, 0.1016, 0.1095]])
```

<Tip>

**Caveat**

Since the compiled model only accepts tensors with a fixed batch size, the `NeuronModelForSentenceTransformers` class pads the inputs to the largest batch size among all input tensors and removes the padding from the outputs. You might therefore want to compile with a large `batch_size` (at least `max(image_batch_size, text_batch_size)`) and disable `dynamic_batch_size`, as illustrated in the sketch after this tip.

E.g. in the example above, we compute the similarities between 3 texts and 1 image, so the `batch_size` should be set to at least 3 during compilation.

</Tip>
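A compilation call matching that recommendation could look like the following sketch; the batch size of 3 matches the three texts in the example above, and the `dynamic_batch_size` keyword is an assumption about the export arguments rather than something shown in this tutorial:

```python
from optimum.neuron import NeuronModelForSentenceTransformers

model_id = "sentence-transformers/clip-ViT-B-32"

# batch_size >= max(image_batch_size, text_batch_size) = 3 for the example above
input_shapes = {
    "num_channels": 3,
    "height": 224,
    "width": 224,
    "batch_size": 3,
    "sequence_length": 64,
}

emb_model = NeuronModelForSentenceTransformers.from_pretrained(
    model_id,
    subfolder="0_CLIPModel",
    export=True,
    library_name="sentence_transformers",
    dynamic_batch_size=False,  # assumed keyword; keep dynamic batching disabled as recommended
    **input_shapes,
)
emb_model.save_pretrained("clip_emb")
```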
8 changes: 7 additions & 1 deletion optimum/neuron/modeling.py
@@ -233,7 +233,13 @@ def forward(
with self.neuron_padding_manager(neuron_inputs) as inputs:
outputs = self.model(*inputs)
if "clip" in model_type:
return ModelOutput(text_embeds=outputs[0], image_embeds=outputs[1])
text_embeds = self.remove_padding([outputs[0]], dims=[0], indices=[input_ids.shape[0]])[
0
] # Remove padding on batch_size(0)
image_embeds = self.remove_padding([outputs[1]], dims=[0], indices=[pixel_values.shape[0]])[
0
] # Remove padding on batch_size(0)
return ModelOutput(text_embeds=text_embeds, image_embeds=image_embeds)
else:
# token_embeddings -> (batch_size, sequence_len, hidden_size)
token_embeddings = self.remove_padding(
