-
-
Notifications
You must be signed in to change notification settings - Fork 613
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
268 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,15 @@ | ||
from .function_calls import OpenAISchema, openai_function, openai_schema | ||
from .dsl.multitask import MultiTask | ||
from .patch import patch | ||
from .sql import ChatCompletionSQL, MessageSQL, instrument_with_sqlalchemy | ||
|
||
__all__ = [ | ||
"OpenAISchema", | ||
"openai_function", | ||
"MultiTask", | ||
"openai_schema", | ||
"patch", | ||
"ChatCompletionSQL", | ||
"MessageSQL", | ||
"instrument_with_sqlalchemy", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from patch import instrument_with_sqlalchemy | ||
from sa import ChatCompletionSQL, MessageSQL, Session | ||
|
||
__all__ = ["instrument_with_sqlalchemy", "ChatCompletionSQL", "MessageSQL", "Session"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
TODO: add the created_at, id and called it completion vs response | ||
""" | ||
|
||
try: | ||
from sqlalchemy.orm import Session | ||
except ImportError: | ||
import warnings | ||
|
||
warnings.warn("SQLAlchemy is not installed. Please install it to use this feature.") | ||
|
||
import openai | ||
import inspect | ||
import json | ||
from typing import Callable | ||
from functools import wraps | ||
|
||
from sa import ChatCompletionSQL, MessageSQL | ||
|
||
|
||
def message_sql(index, message, is_response=False): | ||
return MessageSQL( | ||
index=index, | ||
content=message.get("content", None), | ||
role=message["role"], | ||
arguments=message.get("function_call", {}).get("arguments", None), | ||
name=message.get("function_call", {}).get("name", None), | ||
is_function_call="function_call" in message, | ||
is_response=is_response, | ||
) | ||
|
||
|
||
# Synchronous function to insert chat completion | ||
def sync_insert_chat_completion( | ||
engine, | ||
messages: list[dict], | ||
responses: list[dict] = [], | ||
**kwargs, | ||
): | ||
with Session(engine) as session: # type: ignore | ||
chat = ChatCompletionSQL( | ||
id=kwargs.pop("id", None), | ||
created_at=kwargs.pop("created", None), | ||
functions=json.dumps(kwargs.pop("functions", None)), | ||
function_call=json.dumps(kwargs.pop("function_call", None)), | ||
messages=[ | ||
message_sql(index=ii, message=message) | ||
for (ii, message) in enumerate(messages) | ||
], | ||
responses=[ | ||
message_sql(index=resp["index"], message=resp.message, is_response=True) # type: ignore | ||
for resp in responses | ||
], | ||
**kwargs, | ||
) | ||
session.add(chat) | ||
session.commit() | ||
|
||
|
||
def patch_with_engine(engine): | ||
def add_sql_alchemy(func: Callable) -> Callable: | ||
is_async = inspect.iscoroutinefunction(func) | ||
if is_async: | ||
|
||
@wraps(func) | ||
async def new_chatcompletion(*args, **kwargs): # type: ignore | ||
response = await func(*args, **kwargs) | ||
sync_insert_chat_completion( | ||
engine, | ||
messages=kwargs.pop("messages", []), | ||
responses=response.choices, | ||
id=response["id"], | ||
**response["usage"], | ||
**kwargs, | ||
) | ||
return response | ||
|
||
else: | ||
|
||
@wraps(func) | ||
def new_chatcompletion(*args, **kwargs): | ||
response = func(*args, **kwargs) | ||
|
||
sync_insert_chat_completion( | ||
engine, | ||
messages=kwargs.pop("messages", []), | ||
responses=response.choices, | ||
id=response["id"], | ||
**response["usage"], | ||
**kwargs, | ||
) | ||
response._completion_id = response["id"] | ||
return response | ||
|
||
return new_chatcompletion | ||
|
||
return add_sql_alchemy | ||
|
||
|
||
def instrument_with_sqlalchemy(engine): | ||
patcher = patch_with_engine(engine) | ||
original_chatcompletion = openai.ChatCompletion.create | ||
original_chatcompletion_async = openai.ChatCompletion.acreate | ||
openai.ChatCompletion.create = patcher(original_chatcompletion) | ||
openai.ChatCompletion.acreate = patcher(original_chatcompletion_async) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from datetime import datetime | ||
from sqlalchemy import ( | ||
Boolean, | ||
create_engine, | ||
Column, | ||
Integer, | ||
String, | ||
Float, | ||
ForeignKey, | ||
DateTime, | ||
) | ||
from sqlalchemy.orm import declarative_base, relationship, Session | ||
|
||
Base = declarative_base() | ||
|
||
|
||
class MessageSQL(Base): | ||
__tablename__ = "message" | ||
id = Column(Integer, primary_key=True, index=True) | ||
index = Column(Integer) | ||
role = Column(String) | ||
content = Column(String, index=True) | ||
is_function_call = Column(Boolean) | ||
arguments = Column(String) | ||
name = Column(String) | ||
|
||
is_response = Column(Boolean, default=False) | ||
chatcompletion_id = Column(String, ForeignKey("chatcompletion.id")) | ||
chatcompletion = relationship( | ||
"ChatCompletion", back_populates="messages", foreign_keys=[chatcompletion_id] | ||
) | ||
|
||
response_chatcompletion_id = Column(String, ForeignKey("chatcompletion.id")) | ||
response_chatcompletion = relationship( | ||
"ChatCompletion", | ||
back_populates="responses", | ||
foreign_keys=[response_chatcompletion_id], | ||
) | ||
|
||
|
||
class ChatCompletionSQL(Base): | ||
__tablename__ = "chatcompletion" | ||
id = Column(String, primary_key=True) | ||
|
||
messages = relationship( | ||
"Message", | ||
back_populates="chatcompletion", | ||
foreign_keys=[MessageSQL.chatcompletion_id], | ||
) | ||
responses = relationship( | ||
"Message", | ||
back_populates="response_chatcompletion", | ||
foreign_keys=[MessageSQL.response_chatcompletion_id], | ||
) | ||
|
||
created_at = Column(DateTime, default=datetime.utcnow) | ||
temperature = Column(Float) | ||
model = Column(String) | ||
max_tokens = Column(Integer) | ||
prompt_tokens = Column(Integer) | ||
completion_tokens = Column(Integer) | ||
total_tokens = Column(Integer) | ||
functions = Column(String) # TODO: make this a foreign key | ||
function_call = Column(String) # TODO: make this a foreign key | ||
|
||
|
||
def create_all(engine): | ||
Base.metadata.create_all(engine) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pytest | ||
|
||
try: | ||
from sqlalchemy import create_engine | ||
from instructor.sql import instrument_with_sqlalchemy | ||
|
||
engine = create_engine("sqlite:///chat.db", echo=True) | ||
|
||
instrument_with_sqlalchemy(engine) | ||
except ImportError: | ||
pytest.skip("SQLAlchemy not installed", allow_module_level=True) | ||
|
||
import openai | ||
from pydantic import BaseModel | ||
from instructor import OpenAISchema | ||
|
||
|
||
@pytest.mark.skip(reason="I didn't mock the create method") | ||
def test_normal(): | ||
resp = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo-0613", | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are a world class adder", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "1+1", | ||
}, | ||
], | ||
) | ||
assert "2" in resp.choices[0].message.content # type: ignore | ||
|
||
|
||
@pytest.mark.skip(reason="I didn't mock the create method") | ||
def test_schema(): | ||
class Add(OpenAISchema): | ||
a: int | ||
b: int | ||
|
||
resp = openai.ChatCompletion.create( | ||
model="gpt-3.5-turbo-0613", | ||
functions=[Add.openai_schema], | ||
function_call={"name": "Add"}, | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "You are a world class adder", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "1+1", | ||
}, | ||
], | ||
) | ||
add = Add.from_response(resp) | ||
assert add.a == 1 | ||
assert add.b == 1 | ||
|
||
|
||
@pytest.mark.skip(reason="I didn't mock the create method") | ||
def test_response_model(): | ||
from instructor import patch | ||
|
||
patch() | ||
|
||
class Add(BaseModel): | ||
a: int | ||
b: int | ||
|
||
add: Add = openai.ChatCompletion.create( | ||
response_model=Add, | ||
model="gpt-3.5-turbo-0613", | ||
messages=[ | ||
{ | ||
"role": "user", | ||
"content": "1+1", | ||
} | ||
], | ||
) # type: ignore | ||
assert add.a == 1 | ||
assert add.b == 1 |