diff --git a/demo/one-api.ipynb b/demo/one-api.ipynb index d7d3fe3..0b17457 100644 --- a/demo/one-api.ipynb +++ b/demo/one-api.ipynb @@ -2,9 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# 导入变量,或者从环境变量中加载\n", "from dotenv import load_dotenv\n", @@ -15,134 +26,242 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 用户管理" + "设置变量:\n", + "\n", + "```bash\n", + "# ONE API URL\n", + "ONE_API_BASE_URL=\n", + "# ACCESS TOKEN at https://{one-api-url}/panel/profile\n", + "ONE_API_ACCESS_TOKEN=\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 渠道管理" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from one_api_cli import get_users, get_user, update_user, delete_user, create_user" + "from one_api_cli import Channel" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "23 perry_bailey\n", + "22 margarita_caban\n", + "21 jessica_hamilton\n" + ] + } + ], "source": [ - "# 查看用户\n", - "users = get_users()\n", - "for user in users:\n", - " print(user['id'], user['username'])\n", - "print(get_user(1))\n", - "print(get_user(100)) # 不存在的用户" + "# 查看渠道\n", + "channels = Channel.get_channels()\n", + "for channel in channels[:3]:\n", + " print(channel.id, channel.name)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# 新增用户\n", - "username = \"test2\"\n", - "display_name = \"test\"\n", - "password = \"complicated_password\"\n", - "create_user(username, display_name, password)" + "# 新增渠道\n", + "name = \"test_channel\"\n", + "Channel.create(name=name, key='sk-123', base_url = 'https://api.openai.com', models='gpt-test')" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': 24, 'type': 1, 'key': '', 'status': 1, 'name': 'test_channel', 'weight': 0, 'created_time': 1723269632, 'test_time': 0, 'response_time': 0, 'base_url': 'https://api.openai.com', 'other': '', 'balance': 0, 'balance_updated_time': 0, 'models': 'gpt-test', 'group': 'default', 'used_quota': 0, 'model_mapping': '', 'priority': 0, 'config': '{}'}\n", + "{'id': 24, 'type': 1, 'key': '', 'status': 1, 'name': 'new_channel', 'weight': 0, 'created_time': 1723269632, 'test_time': 0, 'response_time': 0, 'base_url': 'https://api.openai.com', 'other': '', 'balance': 0, 'balance_updated_time': 0, 'models': 'gpt-test', 'group': 'default', 'used_quota': 0, 'model_mapping': '', 'priority': 0, 'config': '{}'}\n" + ] + } + ], "source": [ - "# 修改用户信息\n", - "username = \"test2\"\n", - "new_password = \"new_password_233\"\n", - "update_user(username, password=new_password) # 也可以指定 ID" + "# 查看刚刚创建的渠道\n", + "new_channel = Channel.get_channels(page=0)[0]\n", + "print(new_channel.dumps())\n", + "\n", + "# 修改渠道\n", + "new_name = \"new_channel\"\n", + "new_channel.update(name=new_name)\n", + "print(new_channel.dumps())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# 删除用户\n", - "delete_user(username) # 或使用 ID" + "# 删除频道\n", + "new_channel.delete(confim=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 渠道管理" + "## 用户管理" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "from one_api_cli import get_channels, get_channel, create_channel, update_channel, delete_channel" + "from one_api_cli import get_users, get_user, update_user, delete_user, create_user" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "9 test2\n", + "1 rexwang\n", + "1 rexwang\n" + ] + } + ], "source": [ - "# 查看频道\n", - "channels = get_channels()\n", - "for channel in channels:\n", - " print(channel['id'], channel['name'])\n", - "print(get_channel(1))\n", - "print(get_channel(100)) # 不存在的频道" + "# 查看用户\n", + "users = get_users()\n", + "for user in users:\n", + " print(user['id'], user['username'])\n", + "user = get_user(1)\n", + "print(user['id'], user['username'])" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-08-10 13:54:04.080\u001b[0m | \u001b[31m\u001b[1mERROR \u001b[0m | \u001b[36mone_api_cli.account\u001b[0m:\u001b[36mcreate_user\u001b[0m:\u001b[36m82\u001b[0m - \u001b[31m\u001b[1mUNIQUE constraint failed: users.username\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# 新增频道\n", - "name = \"test_channel\"\n", - "create_channel(name, 'sk-123', 'https://api.openai.com', 'gpt-test')" + "# 新增用户\n", + "username = \"test2\"\n", + "display_name = \"test\"\n", + "password = \"complicated_password\"\n", + "create_user(username, display_name, password)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# 修改频道信息\n", - "channel_id = [channel['id'] for channel in channels if channel['name'] == name][0]\n", - "new_name = \"new_channel\"\n", - "update_channel(channel_id, name=new_name)" + "# 修改用户信息\n", + "username = \"test2\"\n", + "userid = [user['id'] for user in users if user['username'] == username][0]\n", + "new_password = \"new_password_233\"\n", + "update_user(userid, password=new_password)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# 删除频道\n", - "delete_channel(channel_id)" + "# 删除用户\n", + "delete_user(username) # 或使用 ID" ] } ], "metadata": { "kernelspec": { - "display_name": "base", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -160,5 +279,5 @@ } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index c19c950..0000000 --- a/setup.cfg +++ /dev/null @@ -1,9 +0,0 @@ -[bdist_wheel] -universal = 1 - -[flake8] -exclude = docs -[tool:pytest] -addopts = --ignore=setup.py - - diff --git a/src/one_api_cli/__init__.py b/src/one_api_cli/__init__.py index ddc788c..c72b0ec 100644 --- a/src/one_api_cli/__init__.py +++ b/src/one_api_cli/__init__.py @@ -5,4 +5,4 @@ __version__ = '0.2.0' from .account import get_users, update_user, get_user, create_user, delete_user -from .channel import get_channels, update_channel, delete_channel, create_channel, get_channel \ No newline at end of file +from .channel import Channel \ No newline at end of file diff --git a/src/one_api_cli/channel.py b/src/one_api_cli/channel.py index 9e2d699..cf353e9 100644 --- a/src/one_api_cli/channel.py +++ b/src/one_api_cli/channel.py @@ -1,177 +1,135 @@ import requests -from .constant import base_url, headers +from .constant import base_url, headers, default_channel_data from loguru import logger -class Channel(): - id: int - """The ID of the channel.""" - - type: int - """The type of the channel. Default to OpenAI, i.e. channel type 1.""" - - key: str - """The api key of the channel.""" - - name: str - """The display name of the channel.""" - - base_url: str - """The base URL of the channel.""" - - models: str - """The models of the channel, separated by commas.""" - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - def update(self, **kwargs): - self.__dict__.update(kwargs) - - channel_url = f"{base_url}/api/channel" -def get_channels(): - """ - Retrieve a list of channels. - - Returns: - list: A list of channel dictionaries. +class Channel(): """ - try: - response = requests.get(channel_url, headers=headers) - response.raise_for_status() - msg = response.json() - if not msg['success']: - logger.error(msg['message']) - return {} - return Channel(**msg['data']) - except requests.RequestException as e: - logger.error(f"Error fetching channels: {e}") - return [] - -def get_channel(channel_id:int)->dict: + A class to represent a channel. """ - Retrieve the data of a channel. - - Returns: - dict: A channel dictionary. - """ - channel_id_url = f"{channel_url}/{channel_id}" - try: - response = requests.get(channel_id_url, headers=headers) - response.raise_for_status() - msg = response.json() - if not msg['success']: - logger.error(msg['message']) - return {} - return Channel(**msg['data']) - except requests.RequestException as e: - logger.error(f"Error fetching channel: {e}") - return {} -def update_channel(channel_id, **options) -> bool: - """ - Update a channel's data. + def __init__(self, id:int): + data = self.fetch_channel_data(id) + if not data: + raise ValueError(f"Channel with ID {id} not found.") + self.__dict__.update(data) - Args: - channel_id (int): The ID of the channel. - **options: The data to update. + def dumps(self): + return self.__dict__.copy() - Returns: - bool: True if the channel is updated successfully, False otherwise. - """ + @classmethod + def fetch_channel_data(cls, id:int) -> dict: + """ + Fetch the data of a channel. + + Args: + id (int): The ID of the channel. + + Returns: + dict: The data of the channel. + """ + channel_id_url = f"{channel_url}/{id}" + response = cls._make_request('get', channel_id_url) + if not response['success']: + logger.error(response['message']) + raise ValueError(f"Channel with ID {id} not found.") + return response['data'] + + @staticmethod + def get_channels(page=None): + """ + Retrieve a list of channels. + + Returns: + list: A list of channel dictionaries. + """ + if page is not None: + suffix = f"/?p={page}" + channel_data = Channel._make_request('get', channel_url + suffix)['data'] + else: + i = 0 + channel_data = [] + while True: + suffix = f"?p={i}" + response = Channel._make_request('get', channel_url + suffix) + new_data = response['data'] + if not new_data: + break + channel_data.extend(new_data) + i += 1 + return [Channel.from_data(**data) for data in channel_data] - try: - channel_data = get_channel(channel_id) - if not channel_data: - logger.error(f"Channel with ID {channel_id} not found.") + @staticmethod + def from_data(**data:dict): + """ + Create a channel object from data. + + Args: + **data: The data of the channel. + + Returns: + Channel: A channel object. + """ + channel = Channel.__new__(Channel) + channel.__dict__.update(data) + return channel + + def update(self, **channel_data): + """Update the channel data.""" + data = self.__dict__.copy() + data.update(channel_data) + response = self._make_request('put', channel_url, json=data) + if not response['success']: + logger.error(response['message']) return False - channel_data.update(options) - response = requests.put(channel_url, headers=headers, json=channel_data) - response.raise_for_status() - msg = response.json() + self.__dict__.update(channel_data) + return True + + def delete(self, confim:bool=True): + """Delete the channel.""" + if confim: + logger.warning(f"Deleting channel {self.name} with ID {self.id}") + c = input("Are you sure? (y/n): ") + if c.lower() != 'y': return False + channel_id_url = f"{channel_url}/{self.id}" + msg = self._make_request('delete', channel_id_url) if not msg['success']: logger.error(msg['message']) return False return True - except requests.RequestException as e: - logger.error(f"Error updating channel: {e}") - return False - -def delete_channel(channel_id) -> bool: - """ - Delete a channel. - - Args: - channel_id (int): The ID of the channel. - Returns: - bool: True if the channel is deleted successfully, False otherwise. - """ - channel_id_url = f"{channel_url}/{channel_id}" - try: - response = requests.delete(channel_id_url, headers=headers) - response.raise_for_status() - msg = response.json() - if not msg['success']: - logger.error(msg['message']) + @staticmethod + def create(**channel_data): + """Create a new channel. + + Args: + name (str): The name of the channel. + key (str): The api key of the channel. + base_url (str): The base URL of the channel. + models (list): The models of the channel. + """ + data = default_channel_data.copy() + data.update(channel_data) + assert None not in data.values(), "Missing required fields" + response = Channel._make_request('post', channel_url, json=data) + if not response['success']: + logger.error(response['message']) return False return True - except requests.RequestException as e: - logger.error(f"Error deleting channel: {e}") - return False - -def create_channel( - name, key, base_url, models, - type: int = 1, - other: str = '', - model_mapping: str = '', - groups: list = ['default'], - config: str = '{}', - is_edit: bool = False, - group: str = 'default' -) -> bool: - """ - Create a new channel. - Args: - name (str): The name of the channel. - key (str): The key of the channel. - base_url (str): The base URL of the channel. - models (list): The models of the channel. - type (int): The type of the channel. Default to OpenAI. - other (str): Other information of the channel. - model_mapping (str): The model mapping of the channel. - groups (list): The groups of the channel. - config (str): The config of the channel. - is_edit (bool): Whether the channel can be edited. - group (str): The group of the channel. + @staticmethod + def _make_request(method:str, url:str, **kwargs): + try: + response = requests.request(method, url, headers=headers, **kwargs) + response.raise_for_status() + return response.json() + except requests.RequestException as e: + logger.error(f"Error making request: {e}") + return {} - Returns: - bool: True if the channel is created successfully, False otherwise. - """ + def __repr__(self) -> str: + return f"Channel({self.__dict__})" - data = { - 'name': name, - 'key': key, - 'base_url': base_url, - 'models': models, - 'type': type, - 'other': other, - 'model_mapping': model_mapping, - 'groups': groups, - 'config': config, - 'is_edit': is_edit, - 'group': group - } - try: - response = requests.post(channel_url, headers=headers, json=data) - response.raise_for_status() - msg = response.json() - if not msg['success']: - logger.error(msg['message']) - return False - return True - except requests.RequestException as e: - logger.error(f"Error creating channel: {e}") - return False \ No newline at end of file + def __str__(self) -> str: + return f"Channel({self.__dict__})" diff --git a/src/one_api_cli/constant.py b/src/one_api_cli/constant.py index 752b999..b6ea09f 100644 --- a/src/one_api_cli/constant.py +++ b/src/one_api_cli/constant.py @@ -9,6 +9,20 @@ assert base_url, "ONE_API_BASE_URL is not set" assert access_token or section_token, "Either ONE_API_ACCESS_TOKEN or ONE_API_SECTION_TOKEN must be set" +default_channel_data = { + "name": None, + "key": None, + "base_url": None, + "models": None, + "type": 1, + "other": "", + "model_mapping": "", + "groups": ["default"], + "config": "{}", + "is_edit": False, + "group": "default" +} + # Headers if access_token: headers = {