diff --git a/Cargo.lock b/Cargo.lock index 6c40e38..1078c61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "serde_json", "serde_yaml 0.9.19", "tokio", + "tokio-stream", ] [[package]] diff --git a/crates/b/Cargo.toml b/crates/b/Cargo.toml index f412400..d7a753b 100644 --- a/crates/b/Cargo.toml +++ b/crates/b/Cargo.toml @@ -32,4 +32,5 @@ async-trait = "0.1.68" # Type erasure for async trait methods tokio = { version = "1.27.0", features = ["full"] } # An event-driven, non-blocking I/O platform for writing asynchronous I/O backed applications.… indicatif = "0.17.3" # A progress bar and cli reporting library for Rust anyhow = "1.0.71" # Flexible concrete Error type built on std::error::Error +tokio-stream = "0.1.14" diff --git a/crates/b/src/chats.rs b/crates/b/src/chats.rs index 107d823..476b034 100644 --- a/crates/b/src/chats.rs +++ b/crates/b/src/chats.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::error::Error; +use tokio_stream::StreamExt; use async_trait::async_trait; use serde_either::SingleOrVec; @@ -8,7 +9,7 @@ use serde_json::from_str; use openai::chats::{Chat, ChatMessage, ChatsApi}; use openai::error::OpenAi as OpenAiError; -use crate::utils::read_from_stdin; +use crate::utils::{read_from_stdin, Spinner}; use crate::{ChatsCommands, Cli, CommandError, CommandHandle, CommandResult}; pub struct ChatsCreateCommand { @@ -165,6 +166,49 @@ impl CommandHandle for ChatsCreateCommand { type CallError = OpenAiError; async fn call(&self) -> Result { - self.api.create().await + let mut spinner = Spinner::new(false); + + log::debug!("Stream is: {:?}", self.api.stream); + + if Some(true) == self.api.stream { + log::debug!("Creating stream"); + + let chunks = match self.api.create_stream().await { + Ok(chunks) => chunks, + Err(e) => { + log::error!("Error creating stream: {}", e); + return Err(OpenAiError::StreamError); + } + }; + + tokio::pin!(chunks); + + while let Some(chunk) = chunks.next().await { + if chunk.is_err() { + log::error!("Error reading stream"); + spinner.err("Error reading stream"); + return Err(OpenAiError::StreamError); + } + + // spinner.ok(); + + let chunk = chunk.unwrap(); + + if let Some(choice) = chunk.choices.get(0) { + if let Some(delta) = &choice.delta { + if let Some(content) = &delta.content { + // print!("{}", content); + spinner.print(content); + } + } + } + } + + spinner.ok(); + Ok(openai::chats::Chat::default()) + } else { + log::debug!("Creating chat"); + self.api.create().await + } } } diff --git a/crates/b/src/main.rs b/crates/b/src/main.rs index 8050ec1..6cd2dba 100644 --- a/crates/b/src/main.rs +++ b/crates/b/src/main.rs @@ -41,7 +41,7 @@ async fn main() -> Result<(), CommandError> { } }; - let spinner = Spinner::new(cli.silent || cli.stream); + let mut spinner = Spinner::new(cli.silent || cli.stream); let result = match command.call().await { Ok(result) => { diff --git a/crates/b/src/utils.rs b/crates/b/src/utils.rs index 10cd63c..fe1ae01 100644 --- a/crates/b/src/utils.rs +++ b/crates/b/src/utils.rs @@ -3,9 +3,21 @@ use std::time::Duration; use indicatif::{ProgressBar, ProgressStyle}; +/// Spinner state +enum SpinnerState { + /// Spinner is running + Running, + /// Spinner is stopped + Stopped, + /// Spinner is silent + Silent, + /// Spinner is errored + Errored, +} + pub struct Spinner { progress_bar: ProgressBar, - silent: bool, + state: SpinnerState, } impl Spinner { @@ -15,7 +27,7 @@ impl Spinner { ProgressBar::hidden() } else { let progress_bar = ProgressBar::new_spinner(); - progress_bar.enable_steady_tick(Duration::from_millis(120)); + progress_bar.enable_steady_tick(Duration::from_millis(100)); progress_bar.set_style( ProgressStyle::with_template("{spinner:.magenta} {msg}") .unwrap() @@ -24,22 +36,36 @@ impl Spinner { progress_bar }; Self { - silent, + state: if silent { + SpinnerState::Silent + } else { + SpinnerState::Running + }, progress_bar, } } + pub fn print(&mut self, msg: &str) { + if let SpinnerState::Running = self.state { + self.progress_bar.suspend(|| { + print!("{}", msg); + }); + } + } + /// Stops the spinner successfully - pub fn ok(&self) { - if !self.silent { - self.progress_bar.finish_and_clear(); + pub fn ok(&mut self) { + if let SpinnerState::Running = self.state { + self.state = SpinnerState::Stopped; + self.progress_bar.finish_and_clear() } } /// Stops the spinner with an error - pub fn err(&self, msg: &str) { - if !self.silent { + pub fn err(&mut self, msg: &str) { + if let SpinnerState::Running = self.state { self.progress_bar.abandon_with_message(msg.to_string()); + self.state = SpinnerState::Errored; } } } diff --git a/crates/openai/src/chats.rs b/crates/openai/src/chats.rs index b3a1bb2..107b599 100644 --- a/crates/openai/src/chats.rs +++ b/crates/openai/src/chats.rs @@ -277,10 +277,6 @@ impl ChatsApi { pub async fn create_stream( &self, ) -> Result>, error::OpenAi> { - if Some(true) == self.stream { - return Err(error::OpenAi::InvalidStream); - } - let mut api = &mut (*self).clone(); let min_available_tokens = api.min_available_tokens.unwrap_or(750);