Skip to content

Commit

Permalink
feat(c): add support for Claude v2 and the new Anthropic streaming API
Browse files Browse the repository at this point in the history
feat(anthropic): add version header

feat(c): add Claude2 model

- Added `Claude2` model enum variant

fix(c): handle stop reasons
- Added logic to handle `stop_reason` in the streaming response

fix(c): make max_tokens_to_sample required
- Made `max_tokens_to_sample` required and changed type to `u32`

fix(c): default max_supported_tokens
- Defaulted `max_supported_tokens` to 4096 and calculated `max_tokens_to_sample` if not provided

fix(c): remove spinner logic
- Removed spinner logic from `openai.rs`

fix(c): remove truncated field
- Removed `truncated` field from `Chunk`

refactor(c): update logic
- Updated `CompleteCreateCommand` in `anthropic.rs`
  • Loading branch information
cloudbridgeuy committed Aug 4, 2023
1 parent 3fe3464 commit 921be77
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 26 deletions.
3 changes: 3 additions & 0 deletions crates/anthropic/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ fn create_headers(api_key: String) -> Result<HeaderMap> {
HeaderValue::from_str(api_key.as_str()).context("can't create authorization header")?;
let content_type =
HeaderValue::from_str("application/json").context("can't create content-type header")?;
let version =
HeaderValue::from_str("2023-06-01").context("can't create anthropic-version header")?;

headers.insert("anthropic-version", version);
headers.insert("X-API-Key", authorization);
headers.insert("Content-Type", content_type);

Expand Down
7 changes: 1 addition & 6 deletions crates/b/src/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ impl CommandHandle<Response> for CompleteCreateCommand {
type CallError = CommandError;

async fn call(&self) -> Result<Response, Self::CallError> {
match self.api.create().await {
Ok(response) => Ok(response),
Err(e) => Err(CommandError::AnthropicError {
body: e.to_string(),
}),
}
todo!()
}
}
42 changes: 31 additions & 11 deletions crates/c/src/commands/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod session;
#[serde(rename_all = "kebab-case")]
pub enum Model {
#[default]
Claude2,
ClaudeV1,
ClaudeV1_100k,
ClaudeInstantV1,
Expand All @@ -28,6 +29,7 @@ pub enum Model {
impl Model {
pub fn as_str(&self) -> &'static str {
match self {
Self::Claude2 => "claude-2",
Self::ClaudeV1 => "claude-v1",
Self::ClaudeV1_100k => "claude-v1-100k",
Self::ClaudeInstantV1 => "claude-instant-v1",
Expand All @@ -40,7 +42,7 @@ impl Model {
pub struct CompleteRequestBody {
pub model: String,
pub prompt: String,
pub max_tokens_to_sample: Option<u32>,
pub max_tokens_to_sample: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<Vec<String>>,
#[serde(skip_serializing_if = "std::ops::Not::not")]
Expand Down Expand Up @@ -216,21 +218,16 @@ pub async fn run(mut options: Options) -> Result<()> {
tokio::pin!(chunks);

while let Some(chunk) = chunks.next().await {
// Stop the spinner.
spinner.stop();

if chunk.is_err() {
color_eyre::eyre::bail!("Error streaming response: {:?}", chunk);
}

let chunk = chunk.unwrap();
tracing::event!(tracing::Level::DEBUG, "Received chunk... {:?}", chunk);

let len = acc.len();
let partial = chunk.completion[len..].to_string();

print!("{partial}");
spinner.print(&chunk.completion);

acc = chunk.completion;
acc.push_str(&chunk.completion);
}
// Add a new line at the end to make sure the prompt is on a new line.
println!();
Expand Down Expand Up @@ -306,19 +303,23 @@ async fn complete_stream(session: &Session) -> Result<impl Stream<Item = Result<
tracing::event!(tracing::Level::INFO, "Open SSE stream...");
}
reqwest_eventsource::Event::Message(message) => {
tracing::event!(tracing::Level::DEBUG, "message: {:?}", message);
tracing::event!(tracing::Level::INFO, "message: {:?}", message);

if message.data == "[DONE]" {
break;
}

if message.event != "completion" {
continue;
}

let chunk: Chunk = match serde_json::from_str(&message.data) {
Ok(chunk) => chunk,
Err(e) => {
tracing::event!(tracing::Level::ERROR, "e: {e}");
if tx
.send(Err(color_eyre::eyre::format_err!(
"Error parsing response: {e}"
"Error parsing event: {e}"
)))
.await
.is_err()
Expand All @@ -329,6 +330,25 @@ async fn complete_stream(session: &Session) -> Result<impl Stream<Item = Result<
}
};

if chunk.stop_reason.is_some() {
let stop_reason = chunk.stop_reason.clone().unwrap();

tracing::event!(
tracing::Level::INFO,
"Stopping stream due to stop_reason: {stop_reason}",
);

if stop_reason == "stop_sequence" {
tracing::event!(
tracing::Level::INFO,
"Found stop sequence: {}",
&chunk.stop.unwrap()
);
}

break;
}

tracing::event!(tracing::Level::DEBUG, "chunk: {:?}", chunk);
if tx.send(Ok(chunk)).await.is_err() {
return;
Expand Down
1 change: 0 additions & 1 deletion crates/c/src/commands/anthropic/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pub struct Chunk {
pub completion: String,
pub stop_reason: Option<String>,
pub model: String,
pub truncated: bool,
pub stop: Option<String>,
}

Expand Down
8 changes: 4 additions & 4 deletions crates/c/src/commands/anthropic/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct Options {
/// Claude and Claude Instant.
pub model: String,
/// A maximum number of tokens to generate before stopping.
pub max_tokens_to_sample: Option<u32>,
pub max_tokens_to_sample: u32,
/// Claude models stop on `\n\nHuman:`, and may include additional built-in stops sequences
/// in the future. By providing the `stop_sequences` parameter, you may include additional
/// strings that will cause the model to stop generation.
Expand Down Expand Up @@ -119,7 +119,7 @@ impl Session {
}

if options.max_tokens_to_sample.is_some() {
self.options.max_tokens_to_sample = options.max_tokens_to_sample;
self.options.max_tokens_to_sample = options.max_tokens_to_sample.unwrap_or(1000);
}

if options.max_supported_tokens.is_some() {
Expand Down Expand Up @@ -177,8 +177,8 @@ impl Session {

/// Returns a valid completion prompt from the list of messages.
pub fn complete_prompt(&self) -> Result<String> {
let max = self.options.max_supported_tokens.unwrap_or(4096)
- self.options.max_tokens_to_sample.unwrap_or(1000);
let max =
self.options.max_supported_tokens.unwrap_or(4096) - self.options.max_tokens_to_sample;
let mut messages = self.messages.clone();
messages.push(Message::new("".to_string(), Role::Assistant, false));

Expand Down
5 changes: 1 addition & 4 deletions crates/c/src/commands/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,6 @@ pub async fn run(mut options: Options) -> Result<()> {
tokio::pin!(chunks);

while let Some(chunk) = chunks.next().await {
// Stop the spinner.
spinner.stop();

if chunk.is_err() {
color_eyre::eyre::bail!("Error streaming response: {:?}", chunk);
}
Expand All @@ -295,7 +292,7 @@ pub async fn run(mut options: Options) -> Result<()> {
if let Some(delta) = &choice.delta {
if let Some(content) = &delta.content {
acc.push_str(content);
print!("{}", content);
spinner.print(content);
}
}
}
Expand Down

0 comments on commit 921be77

Please sign in to comment.