-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
28 lines (20 loc) · 1.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# pip install accelerate
# Demo script: run google/flan-t5-small for (1) a single free-form prompt and
# (2) a padded batch of translation prompts, printing the decoded outputs.
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
# device_map="auto" lets accelerate place the model on the best available
# device (GPU if present, otherwise CPU).
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", device_map="auto")

# --- Single-prompt generation ---
input_text = "What is FAANG?"
# Use model.device rather than hard-coding "cuda": with device_map="auto" the
# model may live on CPU, and .to("cuda") would crash on a GPU-less machine.
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
print("---------------------", input_ids)
outputs = model.generate(input_ids)
# skip_special_tokens=True for consistency with the batched decode below
# (otherwise the <pad>/</s> markers are printed).
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# --- Batched generation with a task prefix ---
task_prefix = "translate English to Russian: "
# use different length sentences to test batching
sentences = ["The house is wonderful.", "I like to work in NYC."]
# padding=True pads the shorter sentence so both fit in one tensor; the
# attention_mask below tells generate() to ignore the pad positions.
inputs = tokenizer([task_prefix + sentence for sentence in sentences], return_tensors="pt", padding=True).to(model.device)
print("--------------------", inputs)
output_sequences = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    #do_sample=False, # disable sampling to test if batching affects output
)
print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))