
Commit

Assistant startup fix
artitw committed Sep 22, 2024
1 parent 415b798 commit ab6161a
Showing 3 changed files with 19 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@
 
 setuptools.setup(
     name="text2text",
-    version="1.5.1",
+    version="1.5.2",
     author="artitw",
     author_email="[email protected]",
     description="Text2Text: Crosslingual NLP/G toolkit",
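Since only the version string changes here, picking up the fix is an ordinary package upgrade. Assuming the package is installed from PyPI, something like:

    pip install -U text2text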
27 changes: 18 additions & 9 deletions text2text/assistant.py
@@ -2,6 +2,7 @@
 import ollama
 import psutil
 import subprocess
+import time
 
 from llama_index.llms.ollama import Ollama
 from llama_index.core.llms import ChatMessage
@@ -19,19 +20,27 @@ def __init__(self, **kwargs):
     self.port = kwargs.get("port", 11434)
     self.model_url = f"{self.host}:{self.port}"
     self.model_name = kwargs.get("model_name", "llama3.1")
-    return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
-    if return_code != 0:
-      print("Cannot install ollama.")
-    return_code = os.system("sudo systemctl enable ollama")
     self.load_model()
     self.client = ollama.Client(host=self.model_url)
 
   def __del__(self):
     ollama.delete(self.model_name)
 
   def load_model(self):
-    sub = subprocess.Popen(
-      f"ollama serve & ollama pull {self.model_name} & ollama run {self.model_name}",
-      shell=True,
-      stdout=subprocess.PIPE
-    )
+    return_code = os.system("sudo apt install -q -y lshw")
+    if return_code != 0:
+      print("Cannot install lshw.")
+    return_code = os.system("curl -fsSL https://ollama.com/install.sh | sh")
+    if return_code != 0:
+      print("Cannot install ollama.")
+    return_code = os.system("sudo systemctl enable ollama")
+    if return_code != 0:
+      print("Cannot enable ollama.")
+    sub = subprocess.Popen(["ollama", "serve"])
+    return_code = os.system("ollama -v")
+    if return_code != 0:
+      print("Cannot serve ollama.")
+    ollama.pull(self.model_name)
 
   def chat_completion(self, messages=[{"role": "user", "content": "hello"}], stream=False, schema=None, **kwargs):
     if is_port_in_use(self.port):
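The rewritten load_model moves installation and startup out of __init__, launches `ollama serve` as a plain background process instead of a shell pipeline, and pulls the model directly via the client library. The commit also adds an `import time`, which a startup wait could make use of; a minimal sketch of such a readiness poll, using only the public ollama Python client (the helper name and timeout below are illustrative, not part of this commit):

    import time
    import ollama

    def wait_for_ollama(host="http://localhost:11434", timeout=30):
      # Illustrative helper (not in this commit): poll the Ollama server
      # until it answers or the timeout elapses.
      client = ollama.Client(host=host)
      deadline = time.time() + timeout
      while time.time() < deadline:
        try:
          client.list()  # succeeds once the server is accepting requests
          return True
        except Exception:
          time.sleep(1)  # server not up yet; retry shortly
      return False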
2 changes: 0 additions & 2 deletions text2text/pytorch_pretrained_bert/modeling.py
@@ -160,8 +160,6 @@ def to_json_string(self):
 try:
     from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
 except ImportError:
-    print("Better speed can be achieved with apex installed.")
-
     class BertLayerNorm(nn.Module):
         def __init__(self, hidden_size, eps=1e-5):
             """Construct a layernorm module in the TF style (epsilon inside the square root).
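For reference, the fallback class that remains when apex is unavailable implements the TF-style layer norm the docstring describes. A minimal sketch of that standard formulation (not copied from this commit):

    import torch
    import torch.nn as nn

    class BertLayerNorm(nn.Module):
        def __init__(self, hidden_size, eps=1e-5):
            # TF-style layer norm: epsilon sits inside the square root.
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)               # per-feature mean
            s = (x - u).pow(2).mean(-1, keepdim=True)  # biased variance
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias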
