Skip to content

Commit

Permalink
add sql
Browse files Browse the repository at this point in the history
  • Loading branch information
jxnl committed Aug 20, 2023
1 parent fbe5697 commit 08764e9
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/sqlmodel-integration/patch_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def sql_message(index, message, is_response=False):
role=message["role"],
arguments=message.get("function_call", {}).get("arguments", None),
name=message.get("function_call", {}).get("name", None),
is_function_call = "function_call" in message,
is_function_call="function_call" in message,
is_response=is_response,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/sqlmodel-integration/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

engine = create_engine("sqlite:///chat.db", echo=True)
instrument_with_sqlalchemy(engine)

patch()


class Add(BaseModel):
a: int
b: int
Expand All @@ -26,4 +26,4 @@ class Add(BaseModel):
)

assert resp.a == 1
assert resp.b == 1
assert resp.b == 1
4 changes: 4 additions & 0 deletions instructor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from .function_calls import OpenAISchema, openai_function, openai_schema
from .dsl.multitask import MultiTask
from .patch import patch
from .sql import ChatCompletionSQL, MessageSQL, instrument_with_sqlalchemy

__all__ = [
"OpenAISchema",
"openai_function",
"MultiTask",
"openai_schema",
"patch",
"ChatCompletionSQL",
"MessageSQL",
"instrument_with_sqlalchemy",
]
4 changes: 1 addition & 3 deletions instructor/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ def new_chatcompletion(
return new_chatcompletion




def patch():
original_chatcompletion = openai.ChatCompletion.create
original_chatcompletion_async = openai.ChatCompletion.acreate
openai.ChatCompletion.create = wrap_chatcompletion(original_chatcompletion)
openai.ChatCompletion.acreate = wrap_chatcompletion(original_chatcompletion_async)
openai.ChatCompletion.acreate = wrap_chatcompletion(original_chatcompletion_async)
4 changes: 4 additions & 0 deletions instructor/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from patch import instrument_with_sqlalchemy
from sa import ChatCompletionSQL, MessageSQL, Session

__all__ = ["instrument_with_sqlalchemy", "ChatCompletionSQL", "MessageSQL", "Session"]
105 changes: 105 additions & 0 deletions instructor/sql/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""
TODO: add the created_at, id and called it completion vs response
"""

try:
from sqlalchemy.orm import Session
except ImportError:
import warnings

warnings.warn("SQLAlchemy is not installed. Please install it to use this feature.")

import openai
import inspect
import json
from typing import Callable
from functools import wraps

from sa import ChatCompletionSQL, MessageSQL


def message_sql(index, message, is_response=False):
return MessageSQL(
index=index,
content=message.get("content", None),
role=message["role"],
arguments=message.get("function_call", {}).get("arguments", None),
name=message.get("function_call", {}).get("name", None),
is_function_call="function_call" in message,
is_response=is_response,
)


# Synchronous function to insert chat completion
def sync_insert_chat_completion(
engine,
messages: list[dict],
responses: list[dict] = [],
**kwargs,
):
with Session(engine) as session: # type: ignore
chat = ChatCompletionSQL(
id=kwargs.pop("id", None),
created_at=kwargs.pop("created", None),
functions=json.dumps(kwargs.pop("functions", None)),
function_call=json.dumps(kwargs.pop("function_call", None)),
messages=[
message_sql(index=ii, message=message)
for (ii, message) in enumerate(messages)
],
responses=[
message_sql(index=resp["index"], message=resp.message, is_response=True) # type: ignore
for resp in responses
],
**kwargs,
)
session.add(chat)
session.commit()


def patch_with_engine(engine):
def add_sql_alchemy(func: Callable) -> Callable:
is_async = inspect.iscoroutinefunction(func)
if is_async:

@wraps(func)
async def new_chatcompletion(*args, **kwargs): # type: ignore
response = await func(*args, **kwargs)
sync_insert_chat_completion(
engine,
messages=kwargs.pop("messages", []),
responses=response.choices,
id=response["id"],
**response["usage"],
**kwargs,
)
return response

else:

@wraps(func)
def new_chatcompletion(*args, **kwargs):
response = func(*args, **kwargs)

sync_insert_chat_completion(
engine,
messages=kwargs.pop("messages", []),
responses=response.choices,
id=response["id"],
**response["usage"],
**kwargs,
)
response._completion_id = response["id"]
return response

return new_chatcompletion

return add_sql_alchemy


def instrument_with_sqlalchemy(engine):
patcher = patch_with_engine(engine)
original_chatcompletion = openai.ChatCompletion.create
original_chatcompletion_async = openai.ChatCompletion.acreate
openai.ChatCompletion.create = patcher(original_chatcompletion)
openai.ChatCompletion.acreate = patcher(original_chatcompletion_async)
68 changes: 68 additions & 0 deletions instructor/sql/sa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from datetime import datetime
from sqlalchemy import (
Boolean,
create_engine,
Column,
Integer,
String,
Float,
ForeignKey,
DateTime,
)
from sqlalchemy.orm import declarative_base, relationship, Session

Base = declarative_base()


class MessageSQL(Base):
__tablename__ = "message"
id = Column(Integer, primary_key=True, index=True)
index = Column(Integer)
role = Column(String)
content = Column(String, index=True)
is_function_call = Column(Boolean)
arguments = Column(String)
name = Column(String)

is_response = Column(Boolean, default=False)
chatcompletion_id = Column(String, ForeignKey("chatcompletion.id"))
chatcompletion = relationship(
"ChatCompletion", back_populates="messages", foreign_keys=[chatcompletion_id]
)

response_chatcompletion_id = Column(String, ForeignKey("chatcompletion.id"))
response_chatcompletion = relationship(
"ChatCompletion",
back_populates="responses",
foreign_keys=[response_chatcompletion_id],
)


class ChatCompletionSQL(Base):
__tablename__ = "chatcompletion"
id = Column(String, primary_key=True)

messages = relationship(
"Message",
back_populates="chatcompletion",
foreign_keys=[MessageSQL.chatcompletion_id],
)
responses = relationship(
"Message",
back_populates="response_chatcompletion",
foreign_keys=[MessageSQL.response_chatcompletion_id],
)

created_at = Column(DateTime, default=datetime.utcnow)
temperature = Column(Float)
model = Column(String)
max_tokens = Column(Integer)
prompt_tokens = Column(Integer)
completion_tokens = Column(Integer)
total_tokens = Column(Integer)
functions = Column(String) # TODO: make this a foreign key
function_call = Column(String) # TODO: make this a foreign key


def create_all(engine):
Base.metadata.create_all(engine) # type: ignore
83 changes: 83 additions & 0 deletions tests/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest

try:
from sqlalchemy import create_engine
from instructor.sql import instrument_with_sqlalchemy

engine = create_engine("sqlite:///chat.db", echo=True)

instrument_with_sqlalchemy(engine)
except ImportError:
pytest.skip("SQLAlchemy not installed", allow_module_level=True)

import openai
from pydantic import BaseModel
from instructor import OpenAISchema


@pytest.mark.skip(reason="I didn't mock the create method")
def test_normal():
resp = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
messages=[
{
"role": "system",
"content": "You are a world class adder",
},
{
"role": "user",
"content": "1+1",
},
],
)
assert "2" in resp.choices[0].message.content # type: ignore


@pytest.mark.skip(reason="I didn't mock the create method")
def test_schema():
class Add(OpenAISchema):
a: int
b: int

resp = openai.ChatCompletion.create(
model="gpt-3.5-turbo-0613",
functions=[Add.openai_schema],
function_call={"name": "Add"},
messages=[
{
"role": "system",
"content": "You are a world class adder",
},
{
"role": "user",
"content": "1+1",
},
],
)
add = Add.from_response(resp)
assert add.a == 1
assert add.b == 1


@pytest.mark.skip(reason="I didn't mock the create method")
def test_response_model():
from instructor import patch

patch()

class Add(BaseModel):
a: int
b: int

add: Add = openai.ChatCompletion.create(
response_model=Add,
model="gpt-3.5-turbo-0613",
messages=[
{
"role": "user",
"content": "1+1",
}
],
) # type: ignore
assert add.a == 1
assert add.b == 1

0 comments on commit 08764e9

Please sign in to comment.