Skip to content

Commit

Permalink
Schema updates and embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
artitw committed Sep 23, 2024
1 parent ab6161a commit 0529f6e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,11 @@ class Song(BaseModel):
result = asst.chat_completion([
{"role": "user", "content": "What is Britney Spears's best song?"}
], schema=Song, max_new_tokens=16)
], schema=Song)
# Song(name='Toxic', artist='Britney Spears')
# Embeddings
asst.embed(["hello, world!", "this will be embedded"])
```

### Tokenization
Expand Down
15 changes: 14 additions & 1 deletion demos/Text2Text_LLM.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,26 @@
"\n",
"result = asst.chat_completion([\n",
" {\"role\": \"user\", \"content\": \"What is Britney Spears's best song?\"}\n",
"], schema=Song, max_new_tokens=16)"
"], schema=Song)\n",
"print(result) #Song(name='Toxic', artist='Britney Spears')"
],
"metadata": {
"id": "e5khHlNQZ0FD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Embeddings\n",
"asst.embed([\"hello, world!\", \"this will be embedded\"])"
],
"metadata": {
"id": "WJX2klusQR9q"
},
"execution_count": null,
"outputs": []
}
]
}
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.5.2",
version="1.5.3",
author="artitw",
author_email="[email protected]",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand Down
7 changes: 5 additions & 2 deletions text2text/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self, **kwargs):
self.model_name = kwargs.get("model_name", "llama3.1")
self.load_model()
self.client = ollama.Client(host=self.model_url)
self.llama_index_client = Ollama(model=self.model_name, request_timeout=120.0)

def __del__(self):
ollama.delete(self.model_name)
Expand All @@ -46,12 +47,14 @@ def chat_completion(self, messages=[{"role": "user", "content": "hello"}], strea
if is_port_in_use(self.port):
if schema:
msgs = [ChatMessage(**m) for m in messages]
llama_index_client = Ollama(model=self.model_name, request_timeout=120.0)
return llama_index_client.as_structured_llm(schema).chat(messages=msgs).raw
return self.llama_index_client.as_structured_llm(schema).chat(messages=msgs).raw
return self.client.chat(model=self.model_name, messages=messages, stream=stream)
self.load_model()
return self.chat_completion(messages=messages, stream=stream, **kwargs)

def embed(self, texts):
return ollama.embed(model=self.model_name, input=texts)

def transform(self, input_lines, src_lang='en', **kwargs):
return self.chat_completion([{"role": "user", "content": input_lines}])["message"]["content"]

Expand Down

0 comments on commit 0529f6e

Please sign in to comment.