Skip to content

Commit

Permalink
feat: enable custom api_base for most clients (#793)
Browse files (browse the repository at this point in the history)
  • Loading branch information
sigoden committed Aug 17, 2024
1 parent 580ed6b commit 669f2c6
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 42 deletions.
30 changes: 19 additions & 11 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ clients:
# See https://ai.google.dev/docs
- type: gemini
api_key: xxx
api_base: https://generativelanguage.googleapis.com/v1beta # Optional
patch:
chat_completions:
'.*':
Expand All @@ -141,6 +142,7 @@ clients:
# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
api_key: sk-ant-xxx
api_base: https://api.anthropic.com/v1 # Optional

# See https://docs.mistral.ai/
- type: openai-compatible
Expand All @@ -151,18 +153,19 @@ clients:
# See https://docs.cohere.com/docs/the-cohere-platform
- type: cohere
api_key: xxx
api_base: https://api.cohere.ai/v1 # Optional

# See https://docs.perplexity.ai/docs/getting-started
- type: openai-compatible
name: perplexity
api_base: https://api.perplexity.ai
api_key: pplx-xxx
api_base: https://api.perplexity.ai

# See https://console.groq.com/docs/quickstart
- type: openai-compatible
name: groq
api_base: https://api.groq.com/openai/v1
api_key: gsk_xxx
api_base: https://api.groq.com/openai/v1

# See https://github.com/ollama/ollama
- type: ollama
Expand All @@ -179,8 +182,8 @@ clients:

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://{RESOURCE}.openai.azure.com
api_key: xxx
api_base: https://{RESOURCE}.openai.azure.com
models:
- name: gpt-4o # Model deployment name
max_input_tokens: 128000
Expand Down Expand Up @@ -219,6 +222,7 @@ clients:
- type: cloudflare
account_id: xxx
api_key: xxx
api_base: https://api.cloudflare.com/client/v4 # Optional

# See https://replicate.com/docs
- type: replicate
Expand All @@ -232,66 +236,70 @@ clients:
# See https://help.aliyun.com/zh/dashscope/
- type: qianwen
api_key: sk-xxx
api_base: https://dashscope.aliyuncs.com/api/v1 # Optional

# See https://platform.moonshot.cn/docs/intro
- type: openai-compatible
name: moonshot
api_base: https://api.moonshot.cn/v1
api_key: sk-xxx
api_base: https://api.moonshot.cn/v1

# See https://platform.deepseek.com/api-docs/
- type: openai-compatible
name: deepseek
api_key: sk-xxx
api_base: https://api.deepseek.com

# See https://open.bigmodel.cn/dev/howuse/introduction
- type: openai-compatible
name: zhipuai
api_key: xxx
api_base: https://open.bigmodel.cn/api/paas/v4

# See https://platform.lingyiwanwu.com/docs
- type: openai-compatible
name: lingyiwanwu
api_key: xxx
api_base: https://api.lingyiwanwu.com/v1

# See https://deepinfra.com/docs
- type: openai-compatible
name: deepinfra
api_base: https://api.deepinfra.com/v1/openai
api_key: xxx
api_base: https://api.deepinfra.com/v1/openai

# See https://readme.fireworks.ai/docs/quickstart
- type: openai-compatible
name: fireworks
api_base: https://api.fireworks.ai/inference/v1
api_key: xxx
api_base: https://api.fireworks.ai/inference/v1

# See https://openrouter.ai/docs#quick-start
- type: openai-compatible
name: openrouter
api_base: https://openrouter.ai/api/v1
api_key: xxx
api_base: https://openrouter.ai/api/v1

# See https://octo.ai/docs/getting-started/quickstart
- type: openai-compatible
name: octoai
api_base: https://text.octoai.run/v1
api_key: xxx
api_base: https://text.octoai.run/v1

# See https://docs.together.ai/docs/quickstart
- type: openai-compatible
name: together
api_base: https://api.together.xyz/v1
api_key: xxx
api_base: https://api.together.xyz/v1

# See https://jina.ai
- type: openai-compatible
name: jina
api_base: https://api.jina.ai/v1
api_key: xxx
api_base: https://api.jina.ai/v1

# See https://docs.voyageai.com/docs/introduction
- type: openai-compatible
name: voyageai
api_key: xxx
api_base: https://api.voyageai.com/v1
api_key: xxx
10 changes: 8 additions & 2 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const API_BASE: &str = "https://api.anthropic.com/v1/messages";
const API_BASE: &str = "https://api.anthropic.com/v1";

#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -19,6 +20,7 @@ pub struct ClaudeConfig {

impl ClaudeClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -40,10 +42,14 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/messages", api_base.trim_end_matches('/'));
let body = claude_build_chat_completions_body(data, &self_.model)?;

let mut request_data = RequestData::new(API_BASE, body);
let mut request_data = RequestData::new(url, body);

request_data.header("anthropic-version", "2023-06-01");
if let Some(api_key) = api_key {
Expand Down
8 changes: 7 additions & 1 deletion src/client/cloudflare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const API_BASE: &str = "https://api.cloudflare.com/client/v4";
pub struct CloudflareConfig {
pub name: Option<String>,
pub account_id: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
Expand All @@ -21,6 +22,7 @@ pub struct CloudflareConfig {
impl CloudflareClient {
config_get_fn!(account_id, get_account_id);
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 2] = [
("account_id", "Account ID:", true, PromptKind::String),
Expand All @@ -45,9 +47,13 @@ fn prepare_chat_completions(
) -> Result<RequestData> {
let account_id = self_.get_account_id()?;
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
"{}/accounts/{account_id}/ai/run/{}",
api_base.trim_end_matches('/'),
self_.model.name()
);

Expand Down
27 changes: 20 additions & 7 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use super::*;
use super::openai_compatible::*;
use super::*;

use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const CHAT_COMPLETIONS_API_URL: &str = "https://api.cohere.ai/v1/chat";
const EMBEDDINGS_API_URL: &str = "https://api.cohere.ai/v1/embed";
const RERANK_API_URL: &str = "https://api.cohere.ai/v1/rerank";
const API_BASE: &str = "https://api.cohere.ai/v1";

#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -22,6 +21,7 @@ pub struct CohereConfig {

impl CohereClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -43,10 +43,14 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/chat", api_base.trim_end_matches('/'));
let body = build_chat_completions_body(data, &self_.model)?;

let mut request_data = RequestData::new(CHAT_COMPLETIONS_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand All @@ -55,6 +59,11 @@ fn prepare_chat_completions(

fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/embed", api_base.trim_end_matches('/'));

let input_type = match data.query {
true => "search_query",
Expand All @@ -67,7 +76,7 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<Requ
"input_type": input_type,
});

let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand All @@ -76,10 +85,14 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<Requ

fn prepare_rerank(self_: &CohereClient, data: RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/rerank", api_base.trim_end_matches('/'));
let body = generic_build_rerank_body(data, &self_.model);

let mut request_data = RequestData::new(RERANK_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand Down
21 changes: 18 additions & 3 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/";
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";

#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -20,6 +21,7 @@ pub struct GeminiConfig {

impl GeminiClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -41,13 +43,22 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};

let url = format!("{API_BASE}{}:{}?key={}", self_.model.name(), func, api_key);
let url = format!(
"{}/models/{}:{}?key={}",
api_base.trim_end_matches('/'),
self_.model.name(),
func,
api_key
);

let body = gemini_build_chat_completions_body(data, &self_.model)?;

Expand All @@ -58,9 +69,13 @@ fn prepare_chat_completions(

fn prepare_embeddings(self_: &GeminiClient, data: EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!(
"{API_BASE}{}:embedContent?key={}",
"{}/models/{}:embedContent?key={}",
api_base.trim_end_matches('/'),
self_.model.name(),
api_key
);
Expand Down
2 changes: 1 addition & 1 deletion src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn prepare_chat_completions(
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{api_base}/chat/completions");
let url = format!("{}/chat/completions", api_base.trim_end_matches('/'));

let body = openai_build_chat_completions_body(data, &self_.model);

Expand Down
Loading

0 comments on commit 669f2c6

Please sign in to comment.