Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable custom api_base for most clients #793

Merged
merged 1 commit into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ clients:
# See https://ai.google.dev/docs
- type: gemini
api_key: xxx
api_base: https://generativelanguage.googleapis.com/v1beta # Optional
patch:
chat_completions:
'.*':
Expand All @@ -141,6 +142,7 @@ clients:
# See https://docs.anthropic.com/claude/reference/getting-started-with-the-api
- type: claude
api_key: sk-ant-xxx
api_base: https://api.anthropic.com/v1 # Optional

# See https://docs.mistral.ai/
- type: openai-compatible
Expand All @@ -151,18 +153,19 @@ clients:
# See https://docs.cohere.com/docs/the-cohere-platform
- type: cohere
api_key: xxx
api_base: https://api.cohere.ai/v1 # Optional

# See https://docs.perplexity.ai/docs/getting-started
- type: openai-compatible
name: perplexity
api_base: https://api.perplexity.ai
api_key: pplx-xxx
api_base: https://api.perplexity.ai

# See https://console.groq.com/docs/quickstart
- type: openai-compatible
name: groq
api_base: https://api.groq.com/openai/v1
api_key: gsk_xxx
api_base: https://api.groq.com/openai/v1

# See https://github.com/jmorganca/ollama
- type: ollama
Expand All @@ -179,8 +182,8 @@ clients:

# See https://learn.microsoft.com/en-us/azure/ai-services/openai/chatgpt-quickstart
- type: azure-openai
api_base: https://{RESOURCE}.openai.azure.com
api_key: xxx
api_base: https://{RESOURCE}.openai.azure.com
models:
- name: gpt-4o # Model deployment name
max_input_tokens: 128000
Expand Down Expand Up @@ -219,6 +222,7 @@ clients:
- type: cloudflare
account_id: xxx
api_key: xxx
api_base: https://api.cloudflare.com/client/v4 # Optional

# See https://replicate.com/docs
- type: replicate
Expand All @@ -232,66 +236,70 @@ clients:
# See https://help.aliyun.com/zh/dashscope/
- type: qianwen
api_key: sk-xxx
api_base: https://dashscope.aliyuncs.com/api/v1 # Optional

# See https://platform.moonshot.cn/docs/intro
- type: openai-compatible
name: moonshot
api_base: https://api.moonshot.cn/v1
api_key: sk-xxx
api_base: https://api.moonshot.cn/v1

# See https://platform.deepseek.com/api-docs/
- type: openai-compatible
name: deepseek
api_key: sk-xxx
api_base: https://api.deepseek.com

# See https://open.bigmodel.cn/dev/howuse/introduction
- type: openai-compatible
name: zhipuai
api_key: xxx
api_base: https://open.bigmodel.cn/api/paas/v4

# See https://platform.lingyiwanwu.com/docs
- type: openai-compatible
name: lingyiwanwu
api_key: xxx
api_base: https://api.lingyiwanwu.com/v1

# See https://deepinfra.com/docs
- type: openai-compatible
name: deepinfra
api_base: https://api.deepinfra.com/v1/openai
api_key: xxx
api_base: https://api.deepinfra.com/v1/openai

# See https://readme.fireworks.ai/docs/quickstart
- type: openai-compatible
name: fireworks
api_base: https://api.fireworks.ai/inference/v1
api_key: xxx
api_base: https://api.fireworks.ai/inference/v1

# See https://openrouter.ai/docs#quick-start
- type: openai-compatible
name: openrouter
api_base: https://openrouter.ai/api/v1
api_key: xxx
api_base: https://openrouter.ai/api/v1

# See https://octo.ai/docs/getting-started/quickstart
- type: openai-compatible
name: octoai
api_base: https://text.octoai.run/v1
api_key: xxx
api_base: https://text.octoai.run/v1

# See https://docs.together.ai/docs/quickstart
- type: openai-compatible
name: together
api_base: https://api.together.xyz/v1
api_key: xxx
api_base: https://api.together.xyz/v1

# See https://jina.ai
- type: openai-compatible
name: jina
api_base: https://api.jina.ai/v1
api_key: xxx
api_base: https://api.jina.ai/v1

# See https://docs.voyageai.com/docs/introduction
- type: openai-compatible
name: voyageai
api_key: xxx
api_base: https://api.voyageai.ai/v1
api_key: xxx
10 changes: 8 additions & 2 deletions src/client/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const API_BASE: &str = "https://api.anthropic.com/v1/messages";
const API_BASE: &str = "https://api.anthropic.com/v1";

#[derive(Debug, Clone, Deserialize)]
pub struct ClaudeConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -19,6 +20,7 @@ pub struct ClaudeConfig {

impl ClaudeClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -40,10 +42,14 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key().ok();
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/messages", api_base.trim_end_matches('/'));
let body = claude_build_chat_completions_body(data, &self_.model)?;

let mut request_data = RequestData::new(API_BASE, body);
let mut request_data = RequestData::new(url, body);

request_data.header("anthropic-version", "2023-06-01");
if let Some(api_key) = api_key {
Expand Down
8 changes: 7 additions & 1 deletion src/client/cloudflare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const API_BASE: &str = "https://api.cloudflare.com/client/v4";
pub struct CloudflareConfig {
pub name: Option<String>,
pub account_id: Option<String>,
pub api_base: Option<String>,
pub api_key: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
Expand All @@ -21,6 +22,7 @@ pub struct CloudflareConfig {
impl CloudflareClient {
config_get_fn!(account_id, get_account_id);
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 2] = [
("account_id", "Account ID:", true, PromptKind::String),
Expand All @@ -45,9 +47,13 @@ fn prepare_chat_completions(
) -> Result<RequestData> {
let account_id = self_.get_account_id()?;
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!(
"{API_BASE}/accounts/{account_id}/ai/run/{}",
"{}/accounts/{account_id}/ai/run/{}",
api_base.trim_end_matches('/'),
self_.model.name()
);

Expand Down
27 changes: 20 additions & 7 deletions src/client/cohere.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
use super::*;
use super::openai_compatible::*;
use super::*;

use anyhow::{bail, Context, Result};
use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const CHAT_COMPLETIONS_API_URL: &str = "https://api.cohere.ai/v1/chat";
const EMBEDDINGS_API_URL: &str = "https://api.cohere.ai/v1/embed";
const RERANK_API_URL: &str = "https://api.cohere.ai/v1/rerank";
const API_BASE: &str = "https://api.cohere.ai/v1";

#[derive(Debug, Clone, Deserialize, Default)]
pub struct CohereConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -22,6 +21,7 @@ pub struct CohereConfig {

impl CohereClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -43,10 +43,14 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/chat", api_base.trim_end_matches('/'));
let body = build_chat_completions_body(data, &self_.model)?;

let mut request_data = RequestData::new(CHAT_COMPLETIONS_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand All @@ -55,6 +59,11 @@ fn prepare_chat_completions(

fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/embed", api_base.trim_end_matches('/'));

let input_type = match data.query {
true => "search_query",
Expand All @@ -67,7 +76,7 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<Requ
"input_type": input_type,
});

let mut request_data = RequestData::new(EMBEDDINGS_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand All @@ -76,10 +85,14 @@ fn prepare_embeddings(self_: &CohereClient, data: EmbeddingsData) -> Result<Requ

fn prepare_rerank(self_: &CohereClient, data: RerankData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{}/rerank", api_base.trim_end_matches('/'));
let body = generic_build_rerank_body(data, &self_.model);

let mut request_data = RequestData::new(RERANK_API_URL, body);
let mut request_data = RequestData::new(url, body);

request_data.bearer_auth(api_key);

Expand Down
21 changes: 18 additions & 3 deletions src/client/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use reqwest::RequestBuilder;
use serde::Deserialize;
use serde_json::{json, Value};

const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models/";
const API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta";

#[derive(Debug, Clone, Deserialize, Default)]
pub struct GeminiConfig {
pub name: Option<String>,
pub api_key: Option<String>,
pub api_base: Option<String>,
#[serde(default)]
pub models: Vec<ModelData>,
pub patch: Option<RequestPatch>,
Expand All @@ -20,6 +21,7 @@ pub struct GeminiConfig {

impl GeminiClient {
config_get_fn!(api_key, get_api_key);
config_get_fn!(api_base, get_api_base);

pub const PROMPTS: [PromptAction<'static>; 1] =
[("api_key", "API Key:", true, PromptKind::String)];
Expand All @@ -41,13 +43,22 @@ fn prepare_chat_completions(
data: ChatCompletionsData,
) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let func = match data.stream {
true => "streamGenerateContent",
false => "generateContent",
};

let url = format!("{API_BASE}{}:{}?key={}", self_.model.name(), func, api_key);
let url = format!(
"{}/models/{}:{}?key={}",
api_base.trim_end_matches('/'),
self_.model.name(),
func,
api_key
);

let body = gemini_build_chat_completions_body(data, &self_.model)?;

Expand All @@ -58,9 +69,13 @@ fn prepare_chat_completions(

fn prepare_embeddings(self_: &GeminiClient, data: EmbeddingsData) -> Result<RequestData> {
let api_key = self_.get_api_key()?;
let api_base = self_
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!(
"{API_BASE}{}:embedContent?key={}",
"{}/models/{}:embedContent?key={}",
api_base.trim_end_matches('/'),
self_.model.name(),
api_key
);
Expand Down
2 changes: 1 addition & 1 deletion src/client/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn prepare_chat_completions(
.get_api_base()
.unwrap_or_else(|_| API_BASE.to_string());

let url = format!("{api_base}/chat/completions");
let url = format!("{}/chat/completions", api_base.trim_end_matches('/'));

let body = openai_build_chat_completions_body(data, &self_.model);

Expand Down
Loading