Skip to content

Commit

Permalink
feat: add support for chat sessions
Browse files Browse the repository at this point in the history
Add new module, utils, to openai crate; add chat session name comment
and session argument to function in b crate; reorganize message
creation, update model and logic for setting stop words, and remove
HashMap creation in chats file of b crate; remove two error types and
add five new ones, and add file and directory existence checking and
HOME directory retrieval functions to error and utils modules in openai
crate respectively; add cloning and optional session attributes,
new_with_session and store_session methods, get_sessions_file,
deserialize_sessions_file, and serialize_sessions_file methods,
and imports to chats file in openai crate.

---

Changes to file crates/openai/src/lib.rs:

- Added a new module called "utils".

Changes to file crates/b/src/lib.rs:

- Added a comment about chat session name and its use
- Added an argument 'session' with Option<String> type to the function definition using the #[arg(long)] attribute.

Changes to file crates/b/src/chats.rs:

- Updated ChatsApi initialization
- Reorganized message creation
- Updated ChatsApi model, max_tokens, n, user, and stream
- Added check for system message
- Added logic for setting stop words
- Removed HashMap creation and used logit_bias attribute

Changes to file crates/openai/src/error.rs:

- Removed two error types: NoChoices and InvalidBestOf
- Added five new error types: DeserializationError, FileError, InvalidFrequencyPenalty, InvalidPresencePenalty, and NoSession
- Changed the error message of InvalidStop and InvalidLogProbs

Changes to file crates/openai/src/utils.rs:

- Added imports for std::env, std::fs, and std::path::Path
- Added three new functions for checking if a file or directory exists and getting the HOME directory
- The `file_exists` function checks if a file exists and returns a boolean
- The `directory_exists` function checks if a directory exists and returns a boolean
- The `get_home_directory` function returns the HOME directory as a string, appending "/.b/sessions" if available or returning "/tmp/.b/sessions" if not

Changes to file crates/openai/src/chats.rs:

- Added imports for std::fs, std::io, crate::utils
- Added 'Clone' and 'Option' attributes to the session field of ChatsApi struct
- Added new_with_session method to ChatsApi struct for creating new instances of the struct with a given session key
- Added store_session method to ChatsApi struct for storing the current session to a file
- Modified create_message method of ChatsApi struct to update session info in file
- Added get_sessions_file method for obtaining the path to the sessions file
- Added deserialize_sessions_file method for deserializing the sessions file
- Added serialize_sessions_file method for serializing the sessions file
  • Loading branch information
cloudbridgeuy committed Apr 18, 2023
1 parent 444dd40 commit af65fca
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 42 deletions.
71 changes: 41 additions & 30 deletions crates/b/src/chats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ impl ChatsCreateCommand {
match command {
ChatsCommands::Create {
model,
session,
prompt,
system,
max_tokens,
Expand All @@ -37,30 +38,34 @@ impl ChatsCreateCommand {
.as_ref()
.expect("No API Key provided")
.to_string();
let mut api = ChatsApi::new(api_key)?;

match prompt {
Some(s) if s == "-" => {
api.messages = vec![ChatMessage {
content: read_from_stdin()?,
role: "user".to_owned(),
}];
}
Some(s) => {
api.messages = vec![ChatMessage {
content: s.to_owned(),
role: "user".to_owned(),
}];
}
None => {
api.messages = vec![ChatMessage {
content: "".to_owned(),
role: "user".to_owned(),
}];
}
}

let mut api = if let Some(s) = session {
ChatsApi::new_with_session(api_key, s.to_owned())?
} else {
ChatsApi::new(api_key)?
};

let message = match prompt {
Some(s) if s == "-" => ChatMessage {
content: read_from_stdin()?,
role: "user".to_owned(),
},
Some(s) => ChatMessage {
content: s.to_owned(),
role: "user".to_owned(),
},
None => ChatMessage {
content: "".to_owned(),
role: "user".to_owned(),
},
};

api.messages.push(message);

if let Some(s) = system {
if api.messages.first().unwrap().role == "system" {
api.messages.remove(0);
}
api.messages.insert(
0,
ChatMessage {
Expand All @@ -70,21 +75,27 @@ impl ChatsCreateCommand {
);
}

api.model = model.to_owned();
api.max_tokens = max_tokens.to_owned();
api.n = *n;
api.user = user.to_owned();
api.stream = stream.to_owned();
if &api.model != model {
api.model = model.to_owned();
}

stop.as_ref()
.map(|s| api.set_stop(SingleOrVec::Vec(s.to_vec())));
max_tokens.map(|s| api.max_tokens = Some(s));
n.map(|s| api.n = Some(s));
stream.map(|s| api.stream = Some(s));
temperature.map(|s| api.set_temperature(s));
top_p.map(|s| api.set_top_p(s));
presence_penalty.map(|s| api.set_presence_penalty(s));
frequency_penalty.map(|s| api.set_frequency_penalty(s));

if &api.user != user {
api.user = user.to_owned();
}

stop.as_ref()
.map(|s| api.set_stop(SingleOrVec::Vec(s.to_vec())));

if let Some(logit_bias) = logit_bias {
let mut map = HashMap::new();
let mut map = api.logit_bias.unwrap_or(HashMap::new());
for (key, value) in logit_bias {
map.insert(key.to_owned(), *value);
}
Expand Down
3 changes: 3 additions & 0 deletions crates/b/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ pub enum ChatsCommands {
/// or see the following link: https://platform.openai.com/docs/models/overview
#[arg(long, default_value = "gpt-3.5-turbo")]
model: String,
/// Chat session name. Will be used to store previous session interactions.
#[arg(long)]
session: Option<String>,
/// The system message helps set the behavior of the assistant.
#[arg(long)]
system: Option<String>,
Expand Down
137 changes: 134 additions & 3 deletions crates/openai/src/chats.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
use std::collections::HashMap;
use std::fs::{create_dir_all, File};
use std::io::{BufReader, BufWriter};

use log;
use serde::{Deserialize, Serialize};
use serde_either::SingleOrVec;

use crate::client::Client;
use crate::error;
use crate::utils::{directory_exists, file_exists, get_home_directory};

#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ChatsApi {
#[serde(skip)]
client: Client,
Expand All @@ -34,9 +37,11 @@ pub struct ChatsApi {
pub logit_bias: Option<HashMap<u32, f32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub session: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ChatMessage {
pub role: String,
pub content: String,
Expand Down Expand Up @@ -68,6 +73,7 @@ pub struct ChatUsage {
const DEFAULT_MODEL: &str = "gpt-3.5-turbo";

impl ChatsApi {
/// Creates a new ChatsApi instance.
pub fn new(api_key: String) -> Result<Self, error::OpenAi> {
let client = match Client::new(api_key) {
Ok(client) => client,
Expand All @@ -88,6 +94,35 @@ impl ChatsApi {
})
}

/// Creates a new ChatsApi instance by loading the sessions file
pub fn new_with_session(api_key: String, session: String) -> Result<Self, error::OpenAi> {
let session_file = get_sessions_file(&session)?;
let mut chats_api = deserialize_sessions_file(&session_file)?;

chats_api.client = match Client::new(api_key) {
Ok(client) => client,
Err(err) => {
return Err(error::OpenAi::ClientError {
body: err.to_string(),
});
}
};

log::debug!("Created OpenAi HTTP Client");

Ok(chats_api)
}

/// Stores the current session to a file.
pub fn store_session(&self) -> Result<(), error::OpenAi> {
if let Some(session) = &self.session {
let session_file = get_sessions_file(session)?;
serialize_sessions_file(&session_file, self)
} else {
Err(error::OpenAi::NoSession)
}
}

/// Gets the value of the temperature.
pub fn get_temperature(self) -> Option<f32> {
self.temperature
Expand Down Expand Up @@ -189,7 +224,11 @@ impl ChatsApi {

/// Creates a completion for the chat message
pub async fn create(&self) -> Result<Chat, error::OpenAi> {
let request = match serde_json::to_string(&self) {
let mut api = &mut (*self).clone();
let session = api.session.clone();
api.session = None;

let request = match serde_json::to_string(api) {
Ok(request) => request,
Err(err) => {
return Err(error::OpenAi::SerializationError {
Expand Down Expand Up @@ -225,6 +264,98 @@ impl ChatsApi {
}
};

if let Some(session) = session {
let session_file = get_sessions_file(&session)?;
api.session = Some(session);
api.messages
.push(body.choices.first().unwrap().message.clone());
serialize_sessions_file(&session_file, api)?;
}

Ok(body)
}
}

/// Get the path to the sessions file.
pub fn get_sessions_file(session: &str) -> Result<String, error::OpenAi> {
log::debug!("Getting sessions file: {}", session);

let home_dir = get_home_directory();

log::debug!("Home directory: {}", home_dir);

// Create the HOME directory if it doesn't exist
if !directory_exists(&home_dir) {
log::debug!("Creating home directory: {}", home_dir);
create_dir_all(&home_dir).unwrap();
}

let sessions_file = format!("{}/{}", home_dir, session);

// Create the sessions file if it doesn't exist
if !file_exists(&sessions_file) {
log::debug!("Creating sessions file: {}", sessions_file);
File::create(&sessions_file).unwrap();
let mut chats_api = ChatsApi::new(Default::default())?;
chats_api.session = Some(session.to_string());
chats_api.messages = Vec::new();
serialize_sessions_file(&sessions_file, &chats_api)?;
}

log::debug!("Sessions file: {}", sessions_file);

Ok(sessions_file)
}

/// Deserialize the sessions file.
pub fn deserialize_sessions_file(session_file: &str) -> Result<ChatsApi, error::OpenAi> {
log::debug!("Deserializing sessions file: {}", session_file);

let file = match File::open(session_file) {
Ok(file) => file,
Err(err) => {
return Err(error::OpenAi::FileError {
body: err.to_string(),
});
}
};

let reader = BufReader::new(file);

let chats_api: ChatsApi = match serde_json::from_reader(reader) {
Ok(chats_api) => chats_api,
Err(err) => {
return Err(error::OpenAi::DeserializationError {
body: err.to_string(),
});
}
};

Ok(chats_api)
}

/// Serialize the sessions file
pub fn serialize_sessions_file(
session_file: &str,
chats_api: &ChatsApi,
) -> Result<(), error::OpenAi> {
log::debug!("Serializing sessions file: {}", session_file);

let file = match File::create(session_file) {
Ok(file) => file,
Err(err) => {
return Err(error::OpenAi::FileError {
body: err.to_string(),
});
}
};

let writer = BufWriter::new(file);

match serde_json::to_writer_pretty(writer, &chats_api) {
Ok(_) => Ok(()),
Err(err) => Err(error::OpenAi::SerializationError {
body: err.to_string(),
}),
}
}
21 changes: 12 additions & 9 deletions crates/openai/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use custom_error::custom_error;

custom_error! {pub OpenAi
NoChoices = "no chat choices",
InvalidStop{stop: String} = "stop value ({stop}) must be either 'left' or 'right'",
RequestError{body: String} = "request error: {body}",
ModelNotFound{model_name: String} = "model not found: {model_name}",
SerializationError{body: String} = "serialization error: {body}",
ClientError{body: String} = "client error:\n{body}",
InvalidLogProbs{logprobs: f32} = "logprob value ({logprobs}) must be between 0 and 5",
DeserializationError{body: String} = "deserialization error: {body}",
FileError{body: String} = "file error: {body}",
InvalidBestOf = "'best_of' cannot be used with 'stream'",
InvalidEcho = "'echo' cannot be used with 'suffix'",
InvalidFrequencyPenalty{frequency_penalty: f32} = "frequency_penalty ({frequency_penalty}) must be between -2.0 and 2.0",
InvalidLogProbs{logprobs: f32} = "logprob value ({logprobs}) must be between 0 and 5",
InvalidPresencePenalty{presence_penalty: f32} = "presence_penalty value ({presence_penalty}) must be between -2.0 and 2.0",
InvalidStop{stop: String} = "stop value ({stop}) must be either 'left' or 'right'",
InvalidStream = "'stream' cannot be used with 'best_of'",
InvalidSuffix = "'suffix' cannot be used with 'echo'",
InvalidBestOf = "'best_of' cannot be used with 'stream'",
InvalidTemperature{temperature: f32} = "temperature value ({temperature}) must be between 0.0 and 2.0",
InvalidTopP{top_p: f32} = "top_p value ({top_p}) must be between 0 and 1",
InvalidPresencePenalty{presence_penalty: f32} = "presence_penalty value ({presence_penalty}) must be between -2.0 and 2.0",
InvalidFrequencyPenalty{frequency_penalty: f32} = "frequency_penalty ({frequency_penalty}) must be between -2.0 and 2.0",
ModelNotFound{model_name: String} = "model not found: {model_name}",
NoChoices = "no chat choices",
NoSession = "no session",
RequestError{body: String} = "request error: {body}",
SerializationError{body: String} = "serialization error: {body}",
}
1 change: 1 addition & 0 deletions crates/openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod completions;
pub mod edits;
pub mod error;
pub mod models;
pub mod utils;
22 changes: 22 additions & 0 deletions crates/openai/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::env;
use std::fs;
use std::path::Path;

/// Checks if a file exists.
pub fn file_exists(filename: &str) -> bool {
fs::metadata(filename).is_ok()
}

/// Chacks if a directory exists.
pub fn directory_exists(dir_name: &str) -> bool {
let path = Path::new(dir_name);
path.exists() && path.is_dir()
}

/// Get HOME directory.
pub fn get_home_directory() -> String {
match env::var("HOME") {
Ok(val) => val + "/.b/sessions",
Err(_) => String::from("/tmp/.b/sessions"),
}
}

0 comments on commit af65fca

Please sign in to comment.