From 8357d8d8a267b407774e617db2757f6aa675fd00 Mon Sep 17 00:00:00 2001
From: artitw
Date: Sun, 22 Sep 2024 23:50:46 +0000
Subject: [PATCH] Add Text2Text Assistant as LLM for LangChain

---
 setup.py                                      |  2 +-
 text2text/assistant.py                        | 27 ++++++++++++-------
 text2text/pytorch_pretrained_bert/modeling.py |  2 --
 3 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/setup.py b/setup.py
index 9a8da5c..32a57dd 100755
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="text2text",
-    version="1.5.1",
+    version="1.5.2",
     author="artitw",
     author_email="artitw@gmail.com",
     description="Text2Text: Crosslingual NLP/G toolkit",
diff --git a/text2text/assistant.py b/text2text/assistant.py
index 2d8fc49..a777dc8 100644
--- a/text2text/assistant.py
+++ b/text2text/assistant.py
@@ -2,6 +2,7 @@
 import ollama
 import psutil
 import subprocess
+import time
 
 from llama_index.llms.ollama import Ollama
 from llama_index.core.llms import ChatMessage
@@ -19,19 +20,27 @@ def __init__(self, **kwargs):
     self.port = kwargs.get("port", 11434)
     self.model_url = f"{self.host}:{self.port}"
     self.model_name = kwargs.get("model_name", "llama3.1")
-    return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
-    if return_code != 0:
-      print("Cannot install ollama.")
-    return_code = os.system("sudo systemctl enable ollama")
     self.load_model()
     self.client = ollama.Client(host=self.model_url)
 
+  def __del__(self):
+    ollama.delete(self.model_name)
+
   def load_model(self):
-    sub = subprocess.Popen(
-      f"ollama serve & ollama pull {self.model_name} & ollama run {self.model_name}",
-      shell=True,
-      stdout=subprocess.PIPE
-    )
+    return_code = os.system("sudo apt install -q -y lshw")
+    if return_code != 0:
+      print("Cannot install lshw.")
+    return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
+    if return_code != 0:
+      print("Cannot install ollama.")
+    return_code = os.system("sudo systemctl enable ollama")
+    if return_code != 0:
+      print("Cannot enable ollama.")
+    sub = subprocess.Popen(["ollama", "serve"])
+    return_code = os.system("ollama -v")
+    if return_code != 0:
+      print("Cannot serve ollama.")
+    ollama.pull(self.model_name)
 
   def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
     if is_port_in_use(self.port):
diff --git a/text2text/pytorch_pretrained_bert/modeling.py b/text2text/pytorch_pretrained_bert/modeling.py
index 1e3f218..b8ea0fe 100755
--- a/text2text/pytorch_pretrained_bert/modeling.py
+++ b/text2text/pytorch_pretrained_bert/modeling.py
@@ -160,8 +160,6 @@ def to_json_string(self):
 try:
     from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
 except ImportError:
-    print("Better speed can be achieved with apex installed.")
-
     class BertLayerNorm(nn.Module):
         def __init__(self, hidden_size, eps=1e-5):
             """Construct a layernorm module in the TF style (epsilon inside the square root).
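
A note on the new load_model flow: subprocess.Popen(["ollama", "serve"]) returns as soon as the
process is spawned, so the immediate "ollama -v" check and ollama.pull(...) can race the server
startup. The patch adds "import time", which suggests a readiness wait; below is a minimal sketch
of one, assuming the module-level is_port_in_use helper already referenced in chat_completion
(the local fallback definition, retry count, and delay here are illustrative, not from this patch).

import socket
import time

def is_port_in_use(port):
  # Fallback definition: True if something is listening on localhost:port.
  # assistant.py appears to ship its own helper of the same name.
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    return s.connect_ex(("127.0.0.1", port)) == 0

def wait_for_ollama(port=11434, retries=30, delay=1.0):
  # Poll until the ollama server accepts connections, or give up.
  for _ in range(retries):
    if is_port_in_use(port):
      return True
    time.sleep(delay)
  return False

Calling wait_for_ollama(self.port) between the Popen and ollama.pull(self.model_name) would make
the pull depend on actual server readiness rather than startup timing.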
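
For reviewers trying the change, a minimal usage sketch, assuming the package exposes this class
as text2text.Assistant (only host, port, and model_name are visible as constructor options in
this diff).

import text2text as t2t

# The constructor installs and serves ollama, then pulls the model (llama3.1 by default).
asst = t2t.Assistant()
result = asst.chat_completion(messages=[{"role": "user", "content": "hello"}])
print(result)

The exact shape of result is not visible in this diff; it is whatever chat_completion returns
from the underlying ollama client.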
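
The subject line positions the Assistant as an LLM for LangChain, but the bridge itself is not in
this diff. One way to wire it up is LangChain's public custom-LLM interface; the sketch below is
an assumption, not code from this patch, and the response unpacking mirrors ollama's chat format
({"message": {"content": ...}}), which may differ from what chat_completion actually returns.

from typing import Any, List, Optional

from langchain_core.language_models.llms import LLM

import text2text as t2t

class Text2TextLLM(LLM):
  # Hypothetical wrapper exposing the Assistant through LangChain's LLM interface.
  assistant: Any  # a constructed t2t.Assistant instance

  @property
  def _llm_type(self) -> str:
    return "text2text_assistant"

  def _call(self, prompt: str, stop: Optional[List[str]] = None,
            run_manager: Optional[Any] = None, **kwargs: Any) -> str:
    response = self.assistant.chat_completion(
      messages=[{"role": "user", "content": prompt}]
    )
    # Response shape assumed to follow ollama's chat format.
    return response["message"]["content"]

llm = Text2TextLLM(assistant=t2t.Assistant())
print(llm.invoke("hello"))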