diff --git a/Cargo.lock b/Cargo.lock
index 714fb66..6c40e38 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1510,6 +1510,7 @@ dependencies = [
  "serde_json",
  "serde_yaml 0.9.19",
  "tokio",
+ "tokio-stream",
 ]
 
 [[package]]
@@ -2306,6 +2307,17 @@ dependencies = [
  "tokio",
 ]
 
+[[package]]
+name = "tokio-stream"
+version = "0.1.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
+dependencies = [
+ "futures-core",
+ "pin-project-lite",
+ "tokio",
+]
+
 [[package]]
 name = "tokio-util"
 version = "0.7.7"
diff --git a/crates/openai/Cargo.toml b/crates/openai/Cargo.toml
index 803e26d..4961bce 100644
--- a/crates/openai/Cargo.toml
+++ b/crates/openai/Cargo.toml
@@ -29,3 +29,4 @@ serde_yaml = "0.9.19" # YAML data format for Serde
 tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.…
 reqwest-eventsource = "0.4.0"
 futures = "0.3.28"
+tokio-stream = "0.1.14"
diff --git a/crates/openai/src/chats.rs b/crates/openai/src/chats.rs
index 843117a..b3a1bb2 100644
--- a/crates/openai/src/chats.rs
+++ b/crates/openai/src/chats.rs
@@ -2,16 +2,19 @@ use std::collections::HashMap;
 use std::fs::{create_dir_all, File};
 use std::io::{BufReader, BufWriter};
 
-use futures::StreamExt;
+use crate::client::Client;
+use crate::error;
+use crate::utils::{directory_exists, file_exists, get_home_directory};
 use gpt_tokenizer::Default as DefaultTokenizer;
 use log;
 use serde::{Deserialize, Serialize};
 use serde_either::SingleOrVec;
 use serde_json::Value;
-
-use crate::client::Client;
-use crate::error;
-use crate::utils::{directory_exists, file_exists, get_home_directory};
+use std::sync::Arc;
+use tokio::sync::mpsc;
+use tokio::sync::Mutex;
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::{Stream, StreamExt};
 
 #[derive(Debug, Serialize, Deserialize, Default, Clone)]
 pub struct ChatsApi {
@@ -270,8 +273,14 @@ impl ChatsApi {
         Ok(self)
     }
 
-    /// Creates a completion for the chat message
-    pub async fn create(&self) -> Result<Chat, error::OpenAi> {
+    /// Creates a completion for the chat message in stream format.
+    pub async fn create_stream(
+        &self,
+    ) -> Result<impl Stream<Item = Result<Chunk, error::OpenAi>>, error::OpenAi> {
+        if Some(true) == self.stream {
+            return Err(error::OpenAi::InvalidStream);
+        }
+
         let mut api = &mut (*self).clone();
 
         let min_available_tokens = api.min_available_tokens.unwrap_or(750);
@@ -307,25 +316,27 @@ impl ChatsApi {
 
         log::debug!("Request: {}", request);
 
-        if let Some(true) = &self.stream {
-            log::debug!("Streaming completion");
-            let mut event_source = match self.client.post_stream("/chat/completions", request).await
-            {
-                Ok(response) => response,
-                Err(e) => {
-                    return Err(error::OpenAi::RequestError {
-                        body: e.to_string(),
-                    })
-                }
-            };
+        log::debug!("Streaming completion");
+        let mut event_source = match self.client.post_stream("/chat/completions", request).await {
+            Ok(response) => response,
+            Err(e) => {
+                return Err(error::OpenAi::RequestError {
+                    body: e.to_string(),
+                })
+            }
+        };
+
+        let (tx, rx) = mpsc::channel(100);
+        let acc = Arc::new(Mutex::new(String::new()));
+        let acc_clone = Arc::clone(&acc);
 
-            let mut acc: String = String::new();
+        tokio::spawn(async move {
             while let Some(ev) = event_source.next().await {
                 match ev {
-                    Err(e) => {
-                        return Err(error::OpenAi::RequestError {
-                            body: e.to_string(),
-                        })
+                    Err(_) => {
+                        if tx.send(Err(error::OpenAi::StreamError)).await.is_err() {
+                            return;
+                        }
                     }
                     Ok(event) => match event {
                         reqwest_eventsource::Event::Open { .. } => {}
@@ -333,103 +344,138 @@ impl ChatsApi {
                             log::debug!("Message: {:?}", message);
 
                             if message.data == "[DONE]" {
-                                break;
+                                return;
                             }
 
-                            let response: Chunk = match serde_json::from_str(&message.data) {
-                                Err(e) => {
-                                    return Err(error::OpenAi::SerializationError {
-                                        body: e.to_string(),
-                                    })
+                            let chunk: Chunk = match serde_json::from_str(&message.data) {
+                                Err(_) => {
+                                    if tx.send(Err(error::OpenAi::StreamError)).await.is_err() {
+                                        return;
+                                    }
+                                    return;
                                 }
                                 Ok(output) => output,
                             };
 
-                            log::debug!("Response: {:?}", response);
+                            log::debug!("Response: {:?}", chunk);
 
-                            if let Some(choice) = response.choices.get(0) {
+                            if let Some(choice) = &chunk.choices.get(0) {
                                 if let Some(delta) = &choice.delta {
                                     if let Some(content) = &delta.content {
-                                        print!("{}", content);
-                                        acc.push_str(content);
+                                        let mut accumulator = acc.lock().await;
+                                        accumulator.push_str(&content.clone());
                                     }
                                 }
                             }
+
+                            if tx.send(Ok(chunk)).await.is_err() {
+                                return;
+                            }
                         }
                     },
                 }
             }
+        });
+
+        log::debug!("Checking for session, {:?}", session);
+        if let Some(session) = session {
+            let session_file = get_sessions_file(&session)?;
+            api.session = Some(session);
+            api.min_available_tokens = Some(min_available_tokens);
+            api.max_supported_tokens = Some(max_supported_tokens);
+            api.messages = messages;
+
+            let data = acc_clone.lock().await;
+            let data_string = &*data;
+
+            api.messages.push(ChatMessage {
+                content: Some(data_string.to_string()),
+                role: "assistant".to_string(),
+                ..Default::default()
+            });
+            serialize_sessions_file(&session_file, api)?;
+        }
 
-            log::debug!("Returning acc: {}", acc);
-
-            log::debug!("Checking for session, {:?}", session);
-            if let Some(session) = session {
-                let session_file = get_sessions_file(&session)?;
-                api.session = Some(session);
-                api.min_available_tokens = Some(min_available_tokens);
-                api.max_supported_tokens = Some(max_supported_tokens);
-                api.messages = messages;
-                api.messages.push(ChatMessage {
-                    content: Some(acc.clone()),
-                    role: "assistant".to_string(),
-                    ..Default::default()
+        Ok(ReceiverStream::from(rx))
+    }
+
+    /// Creates a completion for the chat message
+    pub async fn create(&self) -> Result<Chat, error::OpenAi> {
+        let mut api = &mut (*self).clone();
+
+        let min_available_tokens = api.min_available_tokens.unwrap_or(750);
+        let max_supported_tokens = api.max_supported_tokens.unwrap_or(4096);
+        let session = api.session.clone();
+        let messages = api.messages.clone();
+
+        api.session = None;
+        api.min_available_tokens = None;
+        api.max_supported_tokens = None;
+        api.messages = trim_messages(
+            api.messages.clone(),
+            max_supported_tokens - min_available_tokens,
+        )?
+        .iter()
+        .map(|m| ChatMessage {
+            role: m.role.clone(),
+            content: m.content.clone(),
+            ..Default::default()
+        })
+        .collect();
+
+        log::debug!("Trimmed messages to {:?}", api.messages);
+
+        let request = match serde_json::to_string(api) {
+            Ok(request) => request,
+            Err(err) => {
+                return Err(error::OpenAi::SerializationError {
+                    body: err.to_string(),
                 });
-                serialize_sessions_file(&session_file, api)?;
             }
+        };
 
-            Ok(Chat {
-                choices: vec![ChatChoice {
-                    message: ChatMessage {
-                        content: Some(acc.clone()),
-                        role: "assistant".to_string(),
-                        ..Default::default()
-                    },
-                    ..Default::default()
-                }],
-                ..Default::default()
-            })
-        } else {
-            let body = match self.client.post("/chat/completions", request).await {
-                Ok(response) => match response.text().await {
-                    Ok(text) => text,
-                    Err(e) => {
-                        return Err(error::OpenAi::RequestError {
-                            body: e.to_string(),
-                        })
-                    }
-                },
+        log::debug!("Request: {}", request);
+
+        let body = match self.client.post("/chat/completions", request).await {
+            Ok(response) => match response.text().await {
+                Ok(text) => text,
                 Err(e) => {
                     return Err(error::OpenAi::RequestError {
                         body: e.to_string(),
                     })
                 }
-            };
+            },
+            Err(e) => {
+                return Err(error::OpenAi::RequestError {
+                    body: e.to_string(),
+                })
+            }
+        };
 
-            log::debug!("Response: {}", body);
+        log::debug!("Response: {}", body);
 
-            let body: Chat = match serde_json::from_str(&body) {
-                Ok(body) => body,
-                Err(e) => {
-                    return Err(error::OpenAi::RequestError {
-                        body: e.to_string(),
-                    })
-                }
-            };
-
-            log::debug!("Checking for session, {:?}", session);
-            if let Some(session) = session {
-                let session_file = get_sessions_file(&session)?;
-                api.session = Some(session);
-                api.min_available_tokens = Some(min_available_tokens);
-                api.max_supported_tokens = Some(max_supported_tokens);
-                api.messages = messages;
-                api.messages
-                    .push(body.choices.first().unwrap().message.clone());
-                serialize_sessions_file(&session_file, api)?;
+        let body: Chat = match serde_json::from_str(&body) {
+            Ok(body) => body,
+            Err(e) => {
+                return Err(error::OpenAi::RequestError {
+                    body: e.to_string(),
+                })
             }
+        };
 
-            Ok(body)
+        log::debug!("Checking for session, {:?}", session);
+        if let Some(session) = session {
+            let session_file = get_sessions_file(&session)?;
+            api.session = Some(session);
+            api.min_available_tokens = Some(min_available_tokens);
+            api.max_supported_tokens = Some(max_supported_tokens);
+            api.messages = messages;
+            api.messages
+                .push(body.choices.first().unwrap().message.clone());
+            serialize_sessions_file(&session_file, api)?;
         }
+
+        Ok(body)
     }
 }
diff --git a/crates/openai/src/error.rs b/crates/openai/src/error.rs
index 36f5b12..6d2d2ed 100644
--- a/crates/openai/src/error.rs
+++ b/crates/openai/src/error.rs
@@ -19,6 +19,7 @@ custom_error! {pub OpenAi
     NoSession = "no session",
     RequestError{body: String} = "request error: {body}",
     SerializationError{body: String} = "serialization error: {body}",
+    StreamError = "stream error",
     TrimError = "could not find a message to trim",
     UknownError = "unknown error",
 }
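Usage sketch (not part of the diff above): a minimal, hypothetical consumer of the new `ChatsApi::create_stream` API. It assumes the reconstructed return type `impl Stream<Item = Result<Chunk, error::OpenAi>>` and an already-configured `ChatsApi` value supplied by the caller; the helper name `print_stream` is illustrative only.

    use tokio_stream::StreamExt;

    // Drain the stream returned by `create_stream`, printing each delta as it
    // arrives and propagating any `error::OpenAi::StreamError` to the caller.
    async fn print_stream(chats_api: &ChatsApi) -> Result<(), error::OpenAi> {
        let mut stream = chats_api.create_stream().await?;

        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            if let Some(choice) = chunk.choices.get(0) {
                if let Some(delta) = &choice.delta {
                    if let Some(content) = &delta.content {
                        print!("{}", content);
                    }
                }
            }
        }

        Ok(())
    }

Because chunks are forwarded over a bounded `mpsc` channel, a caller that drops the stream early causes the spawned task in `create_stream` to exit the next time a send fails, so it does not keep running indefinitely.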