Skip to content

Commit

Permalink
fix: update release-please config and refactor CLI imports
Browse files Browse the repository at this point in the history
- Updated release-please configuration to use more descriptive emojis and hide less relevant sections.
- Refactored CLI imports for better readability and maintainability.
- Added new file for strict type checking configuration in Pyright.
- Updated VSCode settings to exclude new cache directories and enable workspace-wide Python analysis.
  • Loading branch information
liblaf committed Jun 15, 2024
1 parent 2f6e744 commit 4d0d878
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 88 deletions.
22 changes: 12 additions & 10 deletions .github/release-please/config.json
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
{
"$schema": "https://raw.githubusercontent.com/googleapis/release-please/main/schemas/config.json",
"$schema": "https://github.com/googleapis/release-please/raw/main/schemas/config.json",
"changelog-sections": [
{ "type": "feat", "section": "✨ Features" },
{ "type": "fix", "section": "🐛 Bug Fixes" },
{ "type": "perf", "section": "🚀 Performance Improvements" },
{ "type": "revert", "section": "🔙 Reverts" },
{ "type": "docs", "section": "📚 Documentation" },
{ "type": "style", "section": "🎨 Styles" },
{ "type": "chore", "section": "🏗 Miscellaneous Chores" },
{ "type": "refactor", "section": "♻️ Code Refactoring" },
{ "type": "test", "section": "🚦 Tests" },
{ "type": "build", "section": "📦 Build System" },
{ "type": "ci", "section": "💻 Continuous Integration" }
{ "type": "perf", "section": "⚡️ Performance Improvements" },
{ "type": "revert", "section": "⏪ Reverts" },
{ "type": "docs", "section": "📝 Documentation" },
{ "type": "fix", "scope": "deps", "section": "⬆️ Dependencies" },
{ "type": "chore", "scope": "deps", "section": "⬆️ Dependencies" },
{ "type": "style", "section": "🎨 Styles", "hidden": true },
{ "type": "chore", "section": "🏗 Miscellaneous Chores", "hidden": true },
{ "type": "refactor", "section": "♻️ Code Refactoring", "hidden": true },
{ "type": "test", "section": "🚦 Tests", "hidden": true },
{ "type": "build", "section": "📦 Build System", "hidden": true },
{ "type": "ci", "section": "👷 Continuous Integration", "hidden": true }
],
"packages": {
".": {
Expand Down
4 changes: 3 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
{
"files.exclude": {
"**/__pycache__": true,
"**/.ruff_cache": true,
"**/.venv": true
}
},
"python.analysis.diagnosticMode": "workspace"
}
3 changes: 3 additions & 0 deletions pyrightconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"typeCheckingMode": "strict"
}
27 changes: 11 additions & 16 deletions src/aic/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import typer

from aic import config as _config
from aic import log as _log
from aic.cli import list_models as _list_models
from aic.cli import main as _main
from aic import config, log
from aic.cli import list_models as cli_list
from aic.cli import main as cli_main

app = typer.Typer(name="aic")

Expand All @@ -21,24 +20,20 @@ def main(
max_tokens: Annotated[Optional[int], typer.Option()] = None, # noqa: UP007
verify: Annotated[bool, typer.Option()] = True,
) -> None:
_log.init()
log.init()
if list_models:
_list_models.list_models()
cli_list.list_models()
return
if pathspec is None:
pathspec = []
pathspec += [":!*-lock.*", ":!*.lock*", ":!*cspell*"]
config: _config.Config = _config.load()
cfg: config.Config = config.load()
if api_key is not None:
config.api_key = api_key
cfg.api_key = api_key
if base_url is not None:
config.base_url = base_url
cfg.base_url = base_url
if model is not None:
config.model = model
cfg.model = model
if max_tokens is not None:
config.max_tokens = max_tokens
_main.main(
*pathspec,
config=config,
verify=verify,
)
cfg.max_tokens = max_tokens
cli_main.main(*pathspec, cfg=cfg, verify=verify)
10 changes: 5 additions & 5 deletions src/aic/cli/list_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import rich
from rich.table import Table

from aic import pretty as _pretty
from aic.api import openrouter as _openrouter
from aic import pretty
from aic.api import openrouter


def list_models() -> None:
models: list[_openrouter.Model] = _openrouter.get_models()
models: list[openrouter.Model] = openrouter.get_models()
table = Table(title="Models")
table.add_column("ID", style="bright_cyan")
table.add_column("Context", style="bright_magenta", justify="right")
Expand All @@ -19,7 +19,7 @@ def list_models() -> None:
table.add_row(
model.id.removeprefix("openai/"),
babel.numbers.format_number(model.context_length),
_pretty.format_currency(model.pricing.prompt * 1000),
_pretty.format_currency(model.pricing.completion * 1000),
pretty.format_currency(model.pricing.prompt * 1000),
pretty.format_currency(model.pricing.completion * 1000),
)
rich.print(table)
60 changes: 28 additions & 32 deletions src/aic/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,30 @@
from rich.markdown import Markdown
from rich.panel import Panel

from aic import commit_lint as _lint
from aic import config as _config
from aic import git as _git
from aic import pretty as _pretty
from aic import prompt as _prompt
from aic import token as _token
from aic.api import openrouter as _openrouter
from aic import commit_lint as lint
from aic import config, git, pretty, prompt, token
from aic.api import openrouter

if TYPE_CHECKING:
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from openai.types import chat


def main(*pathspec: str, config: _config.Config, verify: bool) -> None:
_git.status(*pathspec)
diff: str = _git.diff(*pathspec)
model_info: _openrouter.Model = _openrouter.get_model(config.model)
client = openai.OpenAI(api_key=config.api_key, base_url=config.base_url)
prompt_builder = _prompt.Prompt()
def main(*pathspec: str, cfg: config.Config, verify: bool) -> None:
git.status(*pathspec)
diff: str = git.diff(*pathspec)
model_info: openrouter.Model = openrouter.get_model(cfg.model)
client = openai.OpenAI(api_key=cfg.api_key, base_url=cfg.base_url)
prompt_builder = prompt.Prompt()
prompt_builder.ask()
prompt: str = prompt_builder.build(diff, model_info, config.max_tokens)
messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt}]
prompt_tokens: int = _token.num_tokens_from_messages(messages, config.model)
response: openai.Stream[ChatCompletionChunk] = client.chat.completions.create(
prompt_str: str = prompt_builder.build(diff, model_info, cfg.max_tokens)
messages: list[chat.ChatCompletionMessageParam] = [
{"role": "user", "content": prompt_str}
]
prompt_tokens: int = token.num_tokens_from_messages(messages, cfg.model)
response: openai.Stream[chat.ChatCompletionChunk] = client.chat.completions.create(
messages=messages,
model=config.model,
max_tokens=config.max_tokens,
model=cfg.model,
max_tokens=cfg.max_tokens,
stream=True,
temperature=0.2,
)
Expand All @@ -43,12 +41,10 @@ def main(*pathspec: str, config: _config.Config, verify: bool) -> None:
completion += "\n"
else:
completion += content
completion_tokens: int = _token.num_tokens_from_string(
completion, config.model
)
completion_tokens: int = token.num_tokens_from_string(completion, cfg.model)
live.update(
Group(
Panel(Markdown(_lint.sanitize(completion))),
Panel(Markdown(lint.sanitize(completion))),
Panel(
format_tokens(prompt_tokens, completion_tokens)
+ "\n"
Expand All @@ -58,24 +54,24 @@ def main(*pathspec: str, config: _config.Config, verify: bool) -> None:
),
)
)
_git.commit(_lint.sanitize(completion), verify=verify)
git.commit(lint.sanitize(completion), verify=verify)


def format_tokens(prompt_tokens: int, completion_tokens: int) -> str:
    """Return a human-readable token-usage summary line.

    Args:
        prompt_tokens: Number of tokens consumed by the prompt.
        completion_tokens: Number of tokens produced in the completion.

    Returns:
        A string of the form
        ``"Tokens: <total> = <prompt> (Prompt) + <completion> (Completion)"``.
    """
    total_tokens: int = prompt_tokens + completion_tokens
    # pretty.format_int presumably adds digit grouping — confirm in aic.pretty.
    total: str = pretty.format_int(total_tokens)
    prompt: str = pretty.format_int(prompt_tokens)
    completion: str = pretty.format_int(completion_tokens)
    return f"Tokens: {total} = {prompt} (Prompt) + {completion} (Completion)"


def format_cost(
    prompt_tokens: int, completion_tokens: int, pricing: openrouter.Model.Pricing
) -> str:
    """Return a human-readable cost summary line.

    Args:
        prompt_tokens: Number of prompt tokens billed.
        completion_tokens: Number of completion tokens billed.
        pricing: Per-token rates; ``pricing.prompt`` and ``pricing.completion``
            are multiplied directly by the token counts.

    Returns:
        A string of the form
        ``"Cost: <total> = <prompt> (Prompt) + <completion> (Completion)"``.
    """
    prompt_cost: float = prompt_tokens * pricing.prompt
    completion_cost: float = completion_tokens * pricing.completion
    total_cost: float = prompt_cost + completion_cost
    # pretty.format_currency presumably renders a currency string — confirm in aic.pretty.
    total: str = pretty.format_currency(total_cost)
    prompt: str = pretty.format_currency(prompt_cost)
    completion: str = pretty.format_currency(completion_cost)
    return f"Cost: {total} = {prompt} (Prompt) + {completion} (Completion)"
16 changes: 8 additions & 8 deletions src/aic/commit_lint.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Sequence

PATTERN: re.Pattern[str] = re.compile(
r"(?P<type>\w+)(?:\((?P<scope>[^\)]+)\))?(?P<breaking>!)?: (?P<description>.+)"
)


def sanitize(msg: str) -> str:
    """Normalize a model-generated commit message.

    Repeatedly strips a surrounding markdown code fence (``` ```) and an
    ``<Answer>...</Answer>`` wrapper plus whitespace until the text stops
    changing, then sanitizes each line individually via ``sanitize_line``.

    Args:
        msg: Raw model output.

    Returns:
        The cleaned, newline-joined commit message.
    """
    msg = msg.strip()
    # Wrappers may be nested (e.g. a fence around an <Answer> tag), so peel
    # them off in a loop until a fixed point is reached.
    while True:
        msg_old: str = msg
        msg = msg.removeprefix("```").removesuffix("```")
        msg = msg.removeprefix("<Answer>").removesuffix("</Answer>")
        msg = msg.strip()
        if msg == msg_old:
            break
    lines: list[str] = [sanitize_line(line) for line in msg.splitlines()]
    return "\n".join(lines)


Expand Down
33 changes: 17 additions & 16 deletions src/aic/prompt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import string
from typing import Any
from typing import TYPE_CHECKING, Any

import questionary
import tiktoken
from loguru import logger

from aic import token as _token
from aic.api import openrouter as _openrouter
from aic.prompt import _type
from aic.prompt import template as _template
from aic import token
from aic.api import openrouter
from aic.prompt import _type, template

TEMPLATE: string.Template = string.Template(_template.TEMPLATE)
if TYPE_CHECKING:
import tiktoken

TEMPLATE: string.Template = string.Template(template.TEMPLATE)


def _ask(question: questionary.Question) -> Any:
Expand Down Expand Up @@ -65,7 +66,7 @@ def ask_breaking_change(self) -> str | None:
self.breaking_change = None
return self.breaking_change

def build(self, diff: str, model: _openrouter.Model, max_tokens: int) -> str:
def build(self, diff: str, model: openrouter.Model, maxtokens: int) -> str:
_: Any
prompt: str = TEMPLATE.substitute(
{
Expand All @@ -77,21 +78,21 @@ def build(self, diff: str, model: _openrouter.Model, max_tokens: int) -> str:
)
model_id: str
_, _, model_id = model.id.partition("/")
num_tokens: int = (
_token.num_tokens_from_messages(
numtokens: int = (
token.num_tokens_from_messages(
[{"role": "user", "content": prompt}], model_id
)
+ max_tokens
+ maxtokens
)
if num_tokens > model.context_length:
encoding: tiktoken.Encoding = tiktoken.encoding_for_model(model_id)
if numtokens > model.context_length:
encoding: tiktoken.Encoding = token.encoding_for_model(model_id)
tokens: list[int] = encoding.encode(diff)
origin_tokens: int = len(tokens)
tokens_truncated: list[int] = tokens[: model.context_length - num_tokens]
origintokens: int = len(tokens)
tokens_truncated: list[int] = tokens[: model.context_length - numtokens]
diff_truncated: str = encoding.decode(tokens_truncated)
logger.warning(
"Truncated diff from {} to {} tokens",
origin_tokens,
origintokens,
len(tokens_truncated),
)
prompt = TEMPLATE.substitute(
Expand Down

0 comments on commit 4d0d878

Please sign in to comment.