From 361162be9b5360e6013861fc10b8741f5d4baa53 Mon Sep 17 00:00:00 2001 From: CloudBridge UY Date: Wed, 12 Apr 2023 12:45:02 -0300 Subject: [PATCH] feat: added the ability to read from stdin --- crates/b/src/chats.rs | 20 +++++++++++++++++++ crates/b/src/completions.rs | 33 +++++++++++++++++++++++++++++++- crates/openai/src/completions.rs | 10 ++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) diff --git a/crates/b/src/chats.rs b/crates/b/src/chats.rs index 4dfe8c7..38c8f12 100644 --- a/crates/b/src/chats.rs +++ b/crates/b/src/chats.rs @@ -1,4 +1,5 @@ use std::error::Error; +use std::io::Read; use async_trait::async_trait; use serde_either::SingleOrVec; @@ -34,10 +35,29 @@ impl ChatsCreateCommand { .expect("No API Key provided") .to_string(); let mut api = ChatsApi::new(api_key)?; + api.messages = vec![ChatMessage { content: prompt.to_owned().join(" "), role: "user".to_owned(), }]; + + let mut stdin = Vec::new(); + // Read from stdin if it's not a tty and don't forget to unlock `stdin` + { + let mut stdin_lock = std::io::stdin().lock(); + stdin_lock.read_to_end(&mut stdin)?; + } + + if !stdin.is_empty() { + api.messages.insert( + 0, + ChatMessage { + content: String::from_utf8_lossy(&stdin).to_string(), + role: "user".to_owned(), + }, + ); + } + api.model = model.to_owned(); api.max_tokens = max_tokens.to_owned(); api.n = *n; diff --git a/crates/b/src/completions.rs b/crates/b/src/completions.rs index f6dad9d..f236e5a 100644 --- a/crates/b/src/completions.rs +++ b/crates/b/src/completions.rs @@ -1,4 +1,7 @@ use std::error::Error; +use std::fmt::Write; +use std::io::Read; +use std::string::String; use async_trait::async_trait; use serde_either::SingleOrVec; @@ -38,7 +41,35 @@ impl CompletionsCreateCommand { .expect("No API key provided") .to_string(); let mut api = CompletionsApi::new(api_key)?; - api.prompt = Some(SingleOrVec::Vec(prompt.clone())); + + let mut stdin = Vec::new(); + // Read from stdin if it's not a tty and don't forget to unlock `stdin` + { + let mut stdin_lock = std::io::stdin().lock(); + stdin_lock.read_to_end(&mut stdin)?; + } + + if !stdin.is_empty() { + if prompt.len() == 0 { + api.prompt = Some(SingleOrVec::Single( + String::from_utf8_lossy(&stdin).to_string(), + )); + } else { + let mut first = String::new(); + write!( + first, + "{}\n{}", + String::from_utf8_lossy(&stdin).to_string(), + prompt.first().unwrap().clone(), + )?; + let mut clone = prompt.clone().iter().skip(1).cloned().collect::>(); + clone.insert(0, first); + api.prompt = Some(SingleOrVec::Vec(clone)); + } + } else { + api.prompt = Some(SingleOrVec::Vec(prompt.clone())); + } + api.model = model.to_string(); api.max_tokens = *max_tokens; api.n = *n; diff --git a/crates/openai/src/completions.rs b/crates/openai/src/completions.rs index 9c86de9..be79ff2 100644 --- a/crates/openai/src/completions.rs +++ b/crates/openai/src/completions.rs @@ -143,6 +143,14 @@ impl CompletionsApi { if let Some(_) = &self.echo { return Err(error::OpenAi::InvalidSuffix); } + + // Can't run 'suffix' with multiple prompts + if let Some(SingleOrVec::Vec(prompts)) = &self.prompt { + if prompts.len() > 1 { + return Err(error::OpenAi::InvalidSuffix); + } + } + self.suffix = Some(suffix); log::debug!("Set suffix to {:?}", &self.suffix); @@ -311,6 +319,8 @@ impl CompletionsApi { } }; + log::debug!("Response body: {}", body); + let body: Completions = match serde_json::from_str(&body) { Ok(body) => body, Err(e) => {