Skip to content

Commit

Permalink
feat(openai): add system option
Browse files Browse the repository at this point in the history
- Added `system` option to `Session`
- Added logic to handle `system` option

fix(openai): make Role lowercase
- Made `Role` enum `lowercase`

fix(openai): add as_str to Model
- Added `as_str()` method to `Model` enum

fix(openai): change model option to enum
- Changed `model` option type to `Model` enum

fix(openai): update API URL
- Changed OpenAI API URL

fix(openai): update complete methods
- Updated `complete()` and `complete_stream()` methods

refactor(openai): rename variables                                                                                                                                                                                                                                                                                                     - No variable renames
  • Loading branch information
cloudbridgeuy committed Aug 8, 2023
1 parent 3869ebe commit b3b6211
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
44 changes: 39 additions & 5 deletions crates/c/src/commands/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio_stream::{Stream, StreamExt};
use ulid::Ulid;

#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
#[default]
/// The user is a human
Expand Down Expand Up @@ -244,6 +245,23 @@ impl Session {
self.meta.format = options.format.unwrap();
}

if options.system.is_some() {
if let Some(m) = self.messages.first_mut() {
if m.role == Role::System {
m.content = options.system.unwrap();
}
} else {
self.messages.insert(
0,
HistoryMessage {
content: options.system.unwrap(),
role: Role::System,
pin: true,
},
);
}
}

self.meta.key = options.openai_api_key;
self.meta.stream = options.stream;
self.meta.silent = options.silent;
Expand Down Expand Up @@ -339,6 +357,17 @@ enum Model {
GPT35Turbo16K,
}

impl Model {
pub fn as_str(&self) -> &'static str {
match self {
Model::GPT4 => "gpt-4",
Model::GPT432K => "gpt-4-32k",
Model::GPT35Turbo => "gpt-3.5-turbo",
Model::GPT35Turbo16K => "gpt-3.5-turbo-16k",
}
}
}

#[derive(Debug, Serialize)]
pub struct CompleteRequestBody {
pub model: String,
Expand Down Expand Up @@ -421,8 +450,8 @@ pub struct Options {
/// API, stdin taking precedence.
prompt: Option<String>,
/// ID of the model to use. See the following link: https://platform.openai.com/docs/models/overview
#[clap(long)]
model: Option<String>,
#[clap(short, long, value_enum, default_value = "gpt4")]
model: Option<Model>,
/// Chat session name. Will be used to store previous session interactions.
#[arg(long)]
session: Option<String>,
Expand Down Expand Up @@ -660,7 +689,7 @@ async fn complete_stream(session: &Session) -> Result<impl Stream<Item = Result<
tracing::event!(tracing::Level::INFO, "Creating client...");
let client = Client::new(session.meta.key.clone())?;

let mut event_source = client.post_stream("/chat/completions", body).await?;
let mut event_source = client.post_stream("/v1/chat/completions", body).await?;

let (tx, rx) = mpsc::channel(100);

Expand Down Expand Up @@ -742,13 +771,18 @@ async fn complete(session: &Session) -> Result<Response> {
tracing::event!(tracing::Level::INFO, "Creating client...");
let client = Client::new(session.meta.key.clone())?;

let res = client.post("/v1/complete", body).await?;
let res = client.post("/v1/chat/completions", body.clone()).await?;
tracing::event!(tracing::Level::INFO, "res: {:?}", res);

let text = res.text().await?;
tracing::event!(tracing::Level::INFO, "text: {:?}", text);

let response: Response = serde_json::from_str(&text)?;
let response: Response = serde_json::from_str(&text).map_err(|e| {
tracing::event!(tracing::Level::ERROR, "Error parsing response text.");
tracing::event!(tracing::Level::ERROR, "body: {body}");
tracing::event!(tracing::Level::ERROR, "text: {text}");
color_eyre::eyre::format_err!("error: {e}")
})?;
tracing::event!(tracing::Level::INFO, "response: {:?}", response);

Ok(response)
Expand Down
2 changes: 1 addition & 1 deletion crates/openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct Client {
headers: HeaderMap,
}

const OPEN_API_URL: &str = "https://api.openai.com/v1";
const OPEN_API_URL: &str = "https://api.openai.com";

fn create_headers(api_key: String) -> Result<HeaderMap, error::OpenAi> {
let mut auth = String::from("Bearer ");
Expand Down

0 comments on commit b3b6211

Please sign in to comment.