From 4d0d878ff70cb6e4f3bb76595df3ed01d711d316 Mon Sep 17 00:00:00 2001 From: liblaf <30631553+liblaf@users.noreply.github.com> Date: Sat, 15 Jun 2024 15:26:54 +0800 Subject: [PATCH] fix: update release-please config and refactor CLI imports - Updated release-please configuration to use more descriptive emojis and hide less relevant sections. - Refactored CLI imports for better readability and maintainability. - Added new file for strict type checking configuration in Pyright. - Updated VSCode settings to exclude new cache directories and enable workspace-wide Python analysis. --- .github/release-please/config.json | 22 ++++++----- .vscode/settings.json | 4 +- pyrightconfig.json | 3 ++ src/aic/cli/__init__.py | 27 ++++++-------- src/aic/cli/list_models.py | 10 ++--- src/aic/cli/main.py | 60 ++++++++++++++---------------- src/aic/commit_lint.py | 16 ++++---- src/aic/prompt/__init__.py | 33 ++++++++-------- 8 files changed, 87 insertions(+), 88 deletions(-) create mode 100644 pyrightconfig.json diff --git a/.github/release-please/config.json b/.github/release-please/config.json index 2820d174..7e1e25b4 100644 --- a/.github/release-please/config.json +++ b/.github/release-please/config.json @@ -1,17 +1,19 @@ { - "$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json", + "$schema": "https://github.com/googleapis/release-please/raw/main/schemas/config.json", "changelog-sections": [ { "type": "feat", "section": "✨ Features" }, { "type": "fix", "section": "🐛 Bug Fixes" }, - { "type": "perf", "section": "🚀 Performance Improvements" }, - { "type": "revert", "section": "🔙 Reverts" }, - { "type": "docs", "section": "📚 Documentation" }, - { "type": "style", "section": "🎨 Styles" }, - { "type": "chore", "section": "🏗 Miscellaneous Chores" }, - { "type": "refactor", "section": "♻️ Code Refactoring" }, - { "type": "test", "section": "🚦 Tests" }, - { "type": "build", "section": "📦 Build System" }, - { "type": "ci", "section": "💻 Continuous Integration" } + { "type": "perf", "section": "⚡️ Performance Improvements" }, + { "type": "revert", "section": "⏪ Reverts" }, + { "type": "docs", "section": "📝 Documentation" }, + { "type": "fix", "scope": "deps", "section": "⬆️ Dependencies" }, + { "type": "chore", "scope": "deps", "section": "⬆️ Dependencies" }, + { "type": "style", "section": "🎨 Styles", "hidden": true }, + { "type": "chore", "section": "🏗 Miscellaneous Chores", "hidden": true }, + { "type": "refactor", "section": "♻️ Code Refactoring", "hidden": true }, + { "type": "test", "section": "🚦 Tests", "hidden": true }, + { "type": "build", "section": "📦 Build System", "hidden": true }, + { "type": "ci", "section": "👷 Continuous Integration", "hidden": true } ], "packages": { ".": { diff --git a/.vscode/settings.json b/.vscode/settings.json index 85d78e72..b9217fc5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,8 @@ { "files.exclude": { "**/__pycache__": true, + "**/.ruff_cache": true, "**/.venv": true - } + }, + "python.analysis.diagnosticMode": "workspace" } diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..0102dcd9 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,3 @@ +{ + "typeCheckingMode": "strict" +} diff --git a/src/aic/cli/__init__.py b/src/aic/cli/__init__.py index 63801922..2c0ab95a 100644 --- a/src/aic/cli/__init__.py +++ b/src/aic/cli/__init__.py @@ -2,10 +2,9 @@ import typer -from aic import config as _config -from aic import log as _log -from aic.cli import list_models as _list_models -from aic.cli import main as _main +from aic import config, log +from aic.cli import list_models as cli_list +from aic.cli import main as cli_main app = typer.Typer(name="aic") @@ -21,24 +20,20 @@ def main( max_tokens: Annotated[Optional[int], typer.Option()] = None, # noqa: UP007 verify: Annotated[bool, typer.Option()] = True, ) -> None: - _log.init() + log.init() if list_models: - _list_models.list_models() + cli_list.list_models() return if pathspec is None: pathspec = [] pathspec += [":!*-lock.*", ":!*.lock*", ":!*cspell*"] - config: _config.Config = _config.load() + cfg: config.Config = config.load() if api_key is not None: - config.api_key = api_key + cfg.api_key = api_key if base_url is not None: - config.base_url = base_url + cfg.base_url = base_url if model is not None: - config.model = model + cfg.model = model if max_tokens is not None: - config.max_tokens = max_tokens - _main.main( - *pathspec, - config=config, - verify=verify, - ) + cfg.max_tokens = max_tokens + cli_main.main(*pathspec, cfg=cfg, verify=verify) diff --git a/src/aic/cli/list_models.py b/src/aic/cli/list_models.py index 28e0f680..7dca8fff 100644 --- a/src/aic/cli/list_models.py +++ b/src/aic/cli/list_models.py @@ -2,12 +2,12 @@ import rich from rich.table import Table -from aic import pretty as _pretty -from aic.api import openrouter as _openrouter +from aic import pretty +from aic.api import openrouter def list_models() -> None: - models: list[_openrouter.Model] = _openrouter.get_models() + models: list[openrouter.Model] = openrouter.get_models() table = Table(title="Models") table.add_column("ID", style="bright_cyan") table.add_column("Context", style="bright_magenta", justify="right") @@ -19,7 +19,7 @@ def list_models() -> None: table.add_row( model.id.removeprefix("openai/"), babel.numbers.format_number(model.context_length), - _pretty.format_currency(model.pricing.prompt * 1000), - _pretty.format_currency(model.pricing.completion * 1000), + pretty.format_currency(model.pricing.prompt * 1000), + pretty.format_currency(model.pricing.completion * 1000), ) rich.print(table) diff --git a/src/aic/cli/main.py b/src/aic/cli/main.py index de7aa301..b6a42288 100644 --- a/src/aic/cli/main.py +++ b/src/aic/cli/main.py @@ -6,32 +6,30 @@ from rich.markdown import Markdown from rich.panel import Panel -from aic import commit_lint as _lint -from aic import config as _config -from aic import git as _git -from aic import pretty as _pretty -from aic import prompt as _prompt -from aic import token as _token -from aic.api import openrouter as _openrouter +from aic import commit_lint as lint +from aic import config, git, pretty, prompt, token +from aic.api import openrouter if TYPE_CHECKING: - from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam + from openai.types import chat -def main(*pathspec: str, config: _config.Config, verify: bool) -> None: - _git.status(*pathspec) - diff: str = _git.diff(*pathspec) - model_info: _openrouter.Model = _openrouter.get_model(config.model) - client = openai.OpenAI(api_key=config.api_key, base_url=config.base_url) - prompt_builder = _prompt.Prompt() +def main(*pathspec: str, cfg: config.Config, verify: bool) -> None: + git.status(*pathspec) + diff: str = git.diff(*pathspec) + model_info: openrouter.Model = openrouter.get_model(cfg.model) + client = openai.OpenAI(api_key=cfg.api_key, base_url=cfg.base_url) + prompt_builder = prompt.Prompt() prompt_builder.ask() - prompt: str = prompt_builder.build(diff, model_info, config.max_tokens) - messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt}] - prompt_tokens: int = _token.num_tokens_from_messages(messages, config.model) - response: openai.Stream[ChatCompletionChunk] = client.chat.completions.create( + prompt_str: str = prompt_builder.build(diff, model_info, cfg.max_tokens) + messages: list[chat.ChatCompletionMessageParam] = [ + {"role": "user", "content": prompt_str} + ] + prompt_tokens: int = token.num_tokens_from_messages(messages, cfg.model) + response: openai.Stream[chat.ChatCompletionChunk] = client.chat.completions.create( messages=messages, - model=config.model, - max_tokens=config.max_tokens, + model=cfg.model, + max_tokens=cfg.max_tokens, stream=True, temperature=0.2, ) @@ -43,12 +41,10 @@ def main(*pathspec: str, config: _config.Config, verify: bool) -> None: completion += "\n" else: completion += content - completion_tokens: int = _token.num_tokens_from_string( - completion, config.model - ) + completion_tokens: int = token.num_tokens_from_string(completion, cfg.model) live.update( Group( - Panel(Markdown(_lint.sanitize(completion))), + Panel(Markdown(lint.sanitize(completion))), Panel( format_tokens(prompt_tokens, completion_tokens) + "\n" @@ -58,24 +54,24 @@ def main(*pathspec: str, config: _config.Config, verify: bool) -> None: ), ) ) - _git.commit(_lint.sanitize(completion), verify=verify) + git.commit(lint.sanitize(completion), verify=verify) def format_tokens(prompt_tokens: int, completion_tokens: int) -> str: total_tokens: int = prompt_tokens + completion_tokens - total: str = _pretty.format_int(total_tokens) - prompt: str = _pretty.format_int(prompt_tokens) - completion: str = _pretty.format_int(completion_tokens) + total: str = pretty.format_int(total_tokens) + prompt: str = pretty.format_int(prompt_tokens) + completion: str = pretty.format_int(completion_tokens) return f"Tokens: {total} = {prompt} (Prompt) + {completion} (Completion)" def format_cost( - prompt_tokens: int, completion_tokens: int, pricing: _openrouter.Model.Pricing + prompt_tokens: int, completion_tokens: int, pricing: openrouter.Model.Pricing ) -> str: prompt_cost: float = prompt_tokens * pricing.prompt completion_cost: float = completion_tokens * pricing.completion total_cost: float = prompt_cost + completion_cost - total: str = _pretty.format_currency(total_cost) - prompt: str = _pretty.format_currency(prompt_cost) - completion: str = _pretty.format_currency(completion_cost) + total: str = pretty.format_currency(total_cost) + prompt: str = pretty.format_currency(prompt_cost) + completion: str = pretty.format_currency(completion_cost) return f"Cost: {total} = {prompt} (Prompt) + {completion} (Completion)" diff --git a/src/aic/commit_lint.py b/src/aic/commit_lint.py index 9a56a65c..cdb22486 100644 --- a/src/aic/commit_lint.py +++ b/src/aic/commit_lint.py @@ -1,8 +1,4 @@ import re -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from collections.abc import Sequence PATTERN: re.Pattern[str] = re.compile( r"(?P\w+)(?:\((?P[^\)]+)\))?(?P!)?: (?P.+)" @@ -10,10 +6,14 @@ def sanitize(msg: str) -> str: - msg = msg.strip() - msg = msg.removeprefix("").removesuffix("") - msg = msg.strip() - lines: Sequence[str] = [sanitize_line(line) for line in msg.splitlines()] + while True: + msg_old: str = msg + msg = msg.removeprefix("```").removesuffix("```") + msg = msg.removeprefix("").removesuffix("") + msg = msg.strip() + if msg == msg_old: + break + lines: list[str] = [sanitize_line(line) for line in msg.splitlines()] return "\n".join(lines) diff --git a/src/aic/prompt/__init__.py b/src/aic/prompt/__init__.py index 0bb1a81f..d6607e70 100644 --- a/src/aic/prompt/__init__.py +++ b/src/aic/prompt/__init__.py @@ -1,16 +1,17 @@ import string -from typing import Any +from typing import TYPE_CHECKING, Any import questionary -import tiktoken from loguru import logger -from aic import token as _token -from aic.api import openrouter as _openrouter -from aic.prompt import _type -from aic.prompt import template as _template +from aic import token +from aic.api import openrouter +from aic.prompt import _type, template -TEMPLATE: string.Template = string.Template(_template.TEMPLATE) +if TYPE_CHECKING: + import tiktoken + +TEMPLATE: string.Template = string.Template(template.TEMPLATE) def _ask(question: questionary.Question) -> Any: @@ -65,7 +66,7 @@ def ask_breaking_change(self) -> str | None: self.breaking_change = None return self.breaking_change - def build(self, diff: str, model: _openrouter.Model, max_tokens: int) -> str: + def build(self, diff: str, model: openrouter.Model, maxtokens: int) -> str: _: Any prompt: str = TEMPLATE.substitute( { @@ -77,21 +78,21 @@ def build(self, diff: str, model: _openrouter.Model, max_tokens: int) -> str: ) model_id: str _, _, model_id = model.id.partition("/") - num_tokens: int = ( - _token.num_tokens_from_messages( + numtokens: int = ( + token.num_tokens_from_messages( [{"role": "user", "content": prompt}], model_id ) - + max_tokens + + maxtokens ) - if num_tokens > model.context_length: - encoding: tiktoken.Encoding = tiktoken.encoding_for_model(model_id) + if numtokens > model.context_length: + encoding: tiktoken.Encoding = token.encoding_for_model(model_id) tokens: list[int] = encoding.encode(diff) - origin_tokens: int = len(tokens) - tokens_truncated: list[int] = tokens[: model.context_length - num_tokens] + origintokens: int = len(tokens) + tokens_truncated: list[int] = tokens[: model.context_length - numtokens] diff_truncated: str = encoding.decode(tokens_truncated) logger.warning( "Truncated diff from {} to {} tokens", - origin_tokens, + origintokens, len(tokens_truncated), ) prompt = TEMPLATE.substitute(