Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/108 build a way for non engineer to iterate on the whisper prompt #109

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 42 additions & 36 deletions assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
import asyncio
import logging
import chromadb

from time import sleep
from openai import OpenAI
from langdetect import detect
from datetime import datetime, timezone
from Akvo_rabbitmq_client import rabbitmq_client
from typing import Optional
from db import connect_to_sqlite, get_stable_prompt

logger = logging.getLogger(__name__)


openai = OpenAI()

CHROMADB_HOST: str = os.getenv("CHROMADB_HOST")
CHROMADB_PORT: int = os.getenv("CHROMADB_PORT")
CHROMADB_DISTANCE_CUTOFF: float = float(os.getenv("CHROMADB_DISTANCE_CUTOFF"))
Expand Down Expand Up @@ -311,44 +315,46 @@ async def on_message(body: str) -> None:
await publish_reliably(queue_message=reply_message)


# Connect to all knowledge bases and store the language-specific connections
# and prompts in the assistant data dictionary.
for language in ASSISTANT_LANGUAGES:
collection_name = CHROMADB_COLLECTION_TEMPLATE.format(language)
knowledge_base = connect_to_chromadb(
CHROMADB_HOST, CHROMADB_PORT, collection_name
)
def main():
# Connect to all knowledge bases and store the language-specific connections
# and prompts in the assistant data dictionary.
sqlite_conn = connect_to_sqlite()
assert sqlite_conn is not None

system_prompt = os.getenv(f"SYSTEM_PROMPT_{language}")
assert (
system_prompt is not None
), f"missing environment variable SYSTEM_PROMPT_{language}"
assert isinstance(system_prompt, str)
assert len(system_prompt) > 0
rag_prompt = os.getenv(f"RAG_PROMPT_{language}")
assert (
rag_prompt is not None
), f"missing environment variable RAG_PROMPT_{language}"
assert isinstance(rag_prompt, str)
assert len(rag_prompt) > 0
ragless_prompt = os.getenv(f"RAGLESS_PROMPT_{language}")
assert (
ragless_prompt is not None
), f"missing environment variable RAGLESS_PROMPT_{language}"
assert isinstance(ragless_prompt, str)
assert len(ragless_prompt) > 0

assistant_data[language] = {
"knowledge_base": knowledge_base,
"system_prompt": system_prompt,
"rag_prompt": rag_prompt,
"ragless_prompt": ragless_prompt,
}
for language in ASSISTANT_LANGUAGES:
collection_name = CHROMADB_COLLECTION_TEMPLATE.format(language)
knowledge_base = connect_to_chromadb(
CHROMADB_HOST, CHROMADB_PORT, collection_name
)

openai = OpenAI()
prompt = get_stable_prompt(
kjkoster marked this conversation as resolved.
Show resolved Hide resolved
conn=sqlite_conn,
language=language,
)
assert prompt is not None

system_prompt = prompt["system_prompt"]
assert isinstance(system_prompt, str)
assert len(system_prompt) > 0

async def main():
rag_prompt = prompt["rag_prompt"]
assert isinstance(rag_prompt, str)
assert len(rag_prompt) > 0

ragless_prompt = prompt["ragless_prompt"]
assert isinstance(ragless_prompt, str)
assert len(ragless_prompt) > 0

assistant_data[language] = {
"knowledge_base": knowledge_base,
"system_prompt": system_prompt,
"rag_prompt": rag_prompt,
"ragless_prompt": ragless_prompt,
}
sqlite_conn.close()


async def consumer():
await rabbitmq_client.initialize()

await rabbitmq_client.consume(
Expand All @@ -367,5 +373,5 @@ async def main():

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

asyncio.run(main())
main()
asyncio.run(consumer())
1 change: 1 addition & 0 deletions assistant/db/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.sqlite
2 changes: 2 additions & 0 deletions assistant/db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .connection import DB_NAME, DB_PATH, connect_to_sqlite # noqa
from .prompt_query import get_stable_prompt # noqa
21 changes: 21 additions & 0 deletions assistant/db/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sqlite3
import logging

from typing import Optional


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


DB_NAME = "assistant"
DB_PATH = "./db/{db_name}.sqlite"
kjkoster marked this conversation as resolved.
Show resolved Hide resolved


def connect_to_sqlite(
    db_name: Optional[str] = DB_NAME,
) -> Optional[sqlite3.Connection]:
    """Open a connection to the assistant's SQLite database.

    The database file is located by substituting ``db_name`` into
    ``DB_PATH``. That template is a relative path, so it resolves against
    the current working directory of the process.

    Args:
        db_name: Base name of the database file, without the ``.sqlite``
            suffix. Defaults to ``DB_NAME``.

    Returns:
        An open connection that the caller is responsible for closing, or
        ``None`` when SQLite reports an error while connecting.
    """
    db_file = DB_PATH.format(db_name=db_name)
    try:
        return sqlite3.connect(db_file)
    except sqlite3.Error as e:
        # Only database errors are turned into ``None``; programming
        # errors are left to propagate instead of being hidden.
        logger.warning("Error connecting to SQLite: %s", e)
        return None
20 changes: 20 additions & 0 deletions assistant/db/prompt_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd

from sqlite3 import Connection


def get_stable_prompt(conn: Connection, language: str):
    """Return the newest stable prompt row for ``language``.

    Joins ``prompt_detail`` with ``prompt`` and keeps only details whose
    prompt group is flagged ``stable``. When several rows match, the one
    with the highest ``prompt_detail.id`` wins.

    Args:
        conn: Open SQLite connection holding the ``prompt`` and
            ``prompt_detail`` tables.
        language: Language code to look up (matched against
            ``prompt_detail.language``).

    Returns:
        A ``pandas.Series`` with the columns of both joined tables for the
        selected row, or ``None`` when no stable prompt exists for
        ``language``.
    """
    # The language is passed as a bound parameter rather than being
    # formatted into the SQL text, so quotes in the value cannot change
    # the statement.
    query = """
        SELECT *
        FROM prompt_detail pd
        LEFT JOIN prompt p
            ON pd.prompt_id = p.id
        WHERE p.stable = 1
            AND pd.language = ?
        ORDER BY pd.id DESC
        LIMIT 1
    """
    df = pd.read_sql_query(query, conn, params=(language,))
    if df.empty:
        return None
    return df.iloc[0]
1 change: 1 addition & 0 deletions assistant/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ black==24.4.2
flake8==7.1.0
langdetect==1.0.9
pytest==8.3.2
pandas==1.5.3
1 change: 1 addition & 0 deletions assistant/run.prod.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env bash
set -eu

python -m seeder.generate_prompt_sqlite
kjkoster marked this conversation as resolved.
Show resolved Hide resolved
python ./assistant.py
1 change: 1 addition & 0 deletions assistant/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ pip -q install --cache-dir=.pip -r requirements.txt

pip -q install --cache-dir=.pip -e /lib/Akvo_rabbitmq_client

python -m seeder.generate_prompt_sqlite
python ./assistant.py
Empty file added assistant/seeder/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions assistant/seeder/generate_prompt_sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import sqlite3
import pandas as pd
import logging

from db import DB_NAME, DB_PATH
from typing import Optional


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

FILENAME = "prompt"


def main(db_name: Optional[str] = DB_NAME):
    """Rebuild the prompt SQLite database from ``./sources/<FILENAME>.csv``.

    The CSV is split into two tables:

    * ``prompt``: one row per distinct value of the CSV's ``stable``
      column, with a generated 1-based ``id``.
    * ``prompt_detail``: the CSV rows, with ``stable`` replaced by a
      ``prompt_id`` that references ``prompt.id``.

    Both tables are replaced if they already exist. Every failure is
    logged and the function returns without raising, so a calling script
    keeps running.

    Args:
        db_name: Base name of the database file; it is substituted into
            ``DB_PATH``. Defaults to ``DB_NAME``.

    Returns:
        Always ``None``.
    """
    csv_file = f"./sources/{FILENAME}.csv"
    sqlite_db = DB_PATH.format(db_name=db_name)

    if not os.path.exists(csv_file):
        logger.warning(f"404 - File not found: {csv_file}")
        return None

    try:
        prompt_detail = pd.read_csv(csv_file)
        # Compute the distinct values once so ids and values stay aligned.
        stable_values = prompt_detail["stable"].unique()
        prompt_group_df = pd.DataFrame(
            {
                "id": range(1, len(stable_values) + 1),
                "stable": stable_values,
            }
        )
        prompt_detail["prompt_id"] = prompt_detail["stable"].map(
            prompt_group_df.set_index("stable")["id"]
        )
        prompt_detail = prompt_detail.drop(columns=["stable"])
    except Exception as e:
        # Script boundary: any parsing or shape problem in the CSV is
        # reported and aborts the seeding instead of raising.
        logger.error(f"Error reading CSV: {e}")
        return None

    if prompt_detail.empty or prompt_group_df.empty:
        logger.warning("404 - CSV file is empty")
        return None

    conn = None
    try:
        conn = sqlite3.connect(sqlite_db)
        prompt_group_df.to_sql(
            "prompt", conn, if_exists="replace", index=False
        )
        prompt_detail.to_sql(
            "prompt_detail", conn, if_exists="replace", index=False
        )
        logger.info(f"{sqlite_db} generated successfully")
    except Exception as e:
        logger.error(f"Error writing to SQLite: {e}")
    finally:
        # Close the connection even when a write fails, so the database
        # file handle is not leaked.
        if conn is not None:
            conn.close()


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions assistant/sources/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.~lock*
4 changes: 4 additions & 0 deletions assistant/sources/prompt.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
id,language,system_prompt,rag_prompt,ragless_prompt,stable
1,en,"Your name is James Otieno. You are a research assistant who finds answers for African Extension Officers, specialised in agriculture in temperate regions. You cannot answer questions that are not about agriculture. You do speak English and all African languages. You always respond in English, regardless of what language the question was asked in. Your answers need to fit well on a WhatsApp screen, so keep your answers short. Answers of around 200 characters in length are best. Your replies are easy to understand by a smallholder farmer who has no formal education.","{prompt}. In your answer, use the following information if it is related: {context}",{prompt},TRUE
2,fr,"Votre nom est James Otieno. Vous êtes un assistant de recherche qui trouve des réponses pour les agents de vulgarisation africains, spécialisés en agriculture dans les régions tempérées. Vous ne pouvez pas répondre aux questions qui ne concernent pas l'agriculture. Vous parlez en français et toutes les langues africaines. Vous répondez toujours en français, quelle que soit la langue dans laquelle la question a été posée. Vos réponses doivent bien s’adapter à un écran WhatsApp, alors gardez-les courtes. Les réponses d’environ 200 caractères sont les meilleures. Vos réponses sont faciles à comprendre pour un petit agriculteur sans éducation formelle.","{prompt}. Dans votre réponse, utilisez les informations suivantes si elles sont liées : {context}",{prompt},TRUE
3,sw,"Jina lako ni James Otieno. Wewe ni mtafiti msaidizi ambaye hupata majibu kwa Maafisa Ugani wa Afrika, waliobobea katika kilimo katika maeneo yenye hali ya hewa baridi. Huwezi kujibu maswali ambayo hayahusu kilimo. Unazungumza Kiswahili na lugha zote za Kiafrika. Unajibu kwa Kiswahili kila wakati, bila kujali swali liliulizwa kwa lugha gani. Majibu yako yanahitaji kutoshea vizuri kwenye skrini ya WhatsApp, kwa hivyo majibu yako yawe mafupi. Majibu ya takriban herufi 200 kwa urefu ndiyo bora zaidi. Majibu yako ni rahisi kueleweka na mkulima mdogo ambaye hana elimu rasmi.","{prompt}. Katika jibu lako, tumia habari ifuatayo ikiwa inahusiana: {context}",{prompt},TRUE
85 changes: 71 additions & 14 deletions assistant/tests/test_assistant.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,64 @@
import logging

from os import environ
from assistant import get_language, assistant_data, query_llm
from db import connect_to_sqlite, get_stable_prompt
from assistant import main, get_language, assistant_data, query_llm
from unittest.mock import patch, MagicMock


def test_connect_to_sqlite():
    """The default database can be opened."""
    conn = connect_to_sqlite()
    assert conn is not None
    # Release the handle so the test does not leak an open connection.
    conn.close()


def test_get_stable_prompt():
    """Stable prompts exist for en/fr/sw and are absent for other codes."""
    conn = connect_to_sqlite()
    assert conn is not None
    try:
        # "id" is a language code this test expects to have no prompt.
        assert get_stable_prompt(conn=conn, language="id") is None

        for language in ("en", "fr", "sw"):
            prompt = get_stable_prompt(conn=conn, language=language)
            assert prompt is not None
            assert prompt["language"] == language
            assert prompt["system_prompt"] is not None
            assert prompt["rag_prompt"] is not None
            assert prompt["ragless_prompt"] is not None
    finally:
        # Close the connection even when an assertion above fails.
        conn.close()


def test_language_support():
main()
conn = connect_to_sqlite()
en_prompt = get_stable_prompt(conn=conn, language="en")
fr_prompt = get_stable_prompt(conn=conn, language="fr")
sw_prompt = get_stable_prompt(conn=conn, language="sw")
conn.close()
# A language that is not configured falls back to English
detected_language = get_language(
"Halo, tolong kirimkan saya rekomendasi tentang pertanian di Kenya."
)
assert detected_language == "en"
knowledge_base = assistant_data[detected_language]["knowledge_base"]
system_prompt = assistant_data[detected_language]["system_prompt"]
rag_prompt = assistant_data[detected_language]["rag_prompt"]
ragless_prompt = assistant_data[detected_language]["ragless_prompt"]
assert knowledge_base.name == "EPPO-datasheets-en"
assert system_prompt == en_prompt["system_prompt"]
assert rag_prompt == en_prompt["rag_prompt"]
assert ragless_prompt == en_prompt["ragless_prompt"]

detected_language = get_language(
"Bonjour, veuillez m'envoyer les recommandations d'agriculture au Kenya"
)
Expand All @@ -15,9 +68,9 @@ def test_language_support():
rag_prompt = assistant_data[detected_language]["rag_prompt"]
ragless_prompt = assistant_data[detected_language]["ragless_prompt"]
assert knowledge_base.name == "EPPO-datasheets-fr"
assert system_prompt == environ["SYSTEM_PROMPT_fr"]
assert rag_prompt == environ["RAG_PROMPT_fr"]
assert ragless_prompt == environ["RAGLESS_PROMPT_fr"]
assert system_prompt == fr_prompt["system_prompt"]
assert rag_prompt == fr_prompt["rag_prompt"]
assert ragless_prompt == fr_prompt["ragless_prompt"]

detected_language = get_language(
"Hello, please send me the recommendations of agriculture in Kenya"
Expand All @@ -28,9 +81,9 @@ def test_language_support():
rag_prompt = assistant_data[detected_language]["rag_prompt"]
ragless_prompt = assistant_data[detected_language]["ragless_prompt"]
assert knowledge_base.name == "EPPO-datasheets-en"
assert system_prompt == environ["SYSTEM_PROMPT_en"]
assert rag_prompt == environ["RAG_PROMPT_en"]
assert ragless_prompt == environ["RAGLESS_PROMPT_en"]
assert system_prompt == en_prompt["system_prompt"]
assert rag_prompt == en_prompt["rag_prompt"]
assert ragless_prompt == en_prompt["ragless_prompt"]

detected_language = get_language(
"Hujambo, tafadhali nitumie mapendekezo ya kilimo nchini Kenya"
Expand All @@ -41,13 +94,17 @@ def test_language_support():
rag_prompt = assistant_data[detected_language]["rag_prompt"]
ragless_prompt = assistant_data[detected_language]["ragless_prompt"]
assert knowledge_base.name == "EPPO-datasheets-sw"
assert system_prompt == environ["SYSTEM_PROMPT_sw"]
assert rag_prompt == environ["RAG_PROMPT_sw"]
assert ragless_prompt == environ["RAGLESS_PROMPT_sw"]
assert system_prompt == sw_prompt["system_prompt"]
assert rag_prompt == sw_prompt["rag_prompt"]
assert ragless_prompt == sw_prompt["ragless_prompt"]


def test_query_llm():
with patch("assistant.OpenAI") as mock_openai:
conn = connect_to_sqlite()
en_prompt = get_stable_prompt(conn=conn, language="en")
conn.close()

mock_content = MagicMock()
mock_content.content = "Mocked response"
mock_choices = MagicMock()
Expand All @@ -62,9 +119,9 @@ def test_query_llm():

llm_client = mock_openai()
model = "my_model"
system_prompt = environ["SYSTEM_PROMPT_en"]
ragless_prompt_template = environ["RAGLESS_PROMPT_en"]
rag_prompt_template = environ["RAG_PROMPT_en"]
system_prompt = en_prompt["system_prompt"]
ragless_prompt_template = en_prompt["ragless_prompt"]
rag_prompt_template = en_prompt["rag_prompt"]
context = ["Knowledge base chunk 1", "Knowledge base chunk 2"]
prompt = (
"Hello, please send me the recommendations of agriculture in Kenya"
Expand Down
Loading