This repository has been archived by the owner on May 28, 2024. It is now read-only.

Lint (#41)
Sets up automatic linting with pre-commit hooks. Also linted all the files.

---------

Signed-off-by: Antoni Baum <[email protected]>
Yard1 authored May 25, 2023
1 parent 0bfe3fd commit 37b0730
Showing 31 changed files with 390 additions and 307 deletions.
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,15 @@
+ci:
+  autoupdate_schedule: monthly
+
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.0.270
+    hooks:
+      - id: ruff
+        args: [ --fix, --exit-non-zero-on-fix ]
+
+  # Black needs to be run after ruff with --fix
+  - repo: https://github.com/psf/black
+    rev: 23.3.0
+    hooks:
+      - id: black
4 changes: 4 additions & 0 deletions README.md
@@ -225,3 +225,7 @@ If you want to help improve or extend the Aviary, please get in touch with us!
 You can [reach us via email](mailto:[email protected]) for feedback and suggestions,
 or [open an issue](https://github.com/ray-project/aviary/issues/new) on GitHub.
 Pull requests are also welcome!
+
+We use `pre-commit` hooks to ensure that all code is formatted correctly.
+Make sure to `pip install pre-commit` and then run `pre-commit install`.
+You can also run `./format` to run the hooks manually.
136 changes: 61 additions & 75 deletions aviary/api/cli.py
@@ -1,41 +1,37 @@
-from aviary.api import sdk
-
 import ast
 import json
 from typing import Annotated, List
 
 import typer
 from rich import print as rp
 from rich.console import Console
-from rich.table import Table
 from rich.progress import Progress, SpinnerColumn, TextColumn
+from rich.table import Table
+
+from aviary.api import sdk
 
 app = typer.Typer()
 
 model_type = typer.Option(
-    default=...,
-    help="The model to use. You can specify multiple models."
-)
-prompt_type = typer.Option(help='Prompt to query')
-stats_type = typer.Option(
-    help='Whether to print generated statistics'
+    default=..., help="The model to use. You can specify multiple models."
 )
+prompt_type = typer.Option(help="Prompt to query")
+stats_type = typer.Option(help="Whether to print generated statistics")
 
 
 @app.command(name="list_models")
 def list_models():
-    """ Get a list of the available models
-    """
+    """Get a list of the available models"""
     result = sdk.models()
-    print('\n'.join(result))
+    print("\n".join(result))
 
 
 def _print_result(result, model, print_stats):
-    rp(f'[bold]{model}:[/]')
-    rp(result['generated_text'])
+    rp(f"[bold]{model}:[/]")
+    rp(result["generated_text"])
     if print_stats:
-        del result['generated_text']
-        rp('[bold]Stats:[/]')
+        del result["generated_text"]
+        rp("[bold]Stats:[/]")
     rp(result)
 
@@ -48,11 +44,12 @@ def progress_spinner():
 
 
 @app.command()
-def query(model: Annotated[List[str], model_type],
-          prompt: Annotated[str, prompt_type],
-          print_stats: Annotated[bool, stats_type] = False):
-    """ Query one or several models with a prompt.
-    """
+def query(
+    model: Annotated[List[str], model_type],
+    prompt: Annotated[str, prompt_type],
+    print_stats: Annotated[bool, stats_type] = False,
+):
+    """Query one or several models with a prompt."""
     with progress_spinner() as progress:
         for m in model:
             progress.add_task(
@@ -63,11 +60,12 @@ def query(model: Annotated[List[str], model_type],
 
 
 @app.command(name="batch_query")
-def batch_query(model: Annotated[List[str], model_type],
-                prompt: Annotated[List[str], prompt_type],
-                print_stats: Annotated[bool, stats_type]):
-    """Query a model with a batch of prompts.
-    """
+def batch_query(
+    model: Annotated[List[str], model_type],
+    prompt: Annotated[List[str], prompt_type],
+    print_stats: Annotated[bool, stats_type],
+):
+    """Query a model with a batch of prompts."""
     with progress_spinner() as progress:
         for m in model:
             progress.add_task(
@@ -89,100 +87,88 @@ def run(model: Annotated[List[str], model_type]):
 
 
 prompt_file_type = typer.Option(
-    default=...,
-    help='File containing prompts. A simple text file'
+    default=..., help="File containing prompts. A simple text file"
 )
-separator_type = typer.Option(help='Separator used in prompt files')
-results_type = typer.Option(help='Where to save the results')
+separator_type = typer.Option(help="Separator used in prompt files")
+results_type = typer.Option(help="Where to save the results")
 
 
 @app.command(name="multi_query")
-def multi_query(model: Annotated[List[str], model_type],
-                prompt_file: Annotated[str, prompt_file_type],
-                separator: Annotated[str, separator_type] = '----',
-                output_file: Annotated[str, results_type] = 'aviary-output.json'):
-    """Query one or multiple models with a batch of prompts taken from a file.
-    """
+def multi_query(
+    model: Annotated[List[str], model_type],
+    prompt_file: Annotated[str, prompt_file_type],
+    separator: Annotated[str, separator_type] = "----",
+    output_file: Annotated[str, results_type] = "aviary-output.json",
+):
+    """Query one or multiple models with a batch of prompts taken from a file."""
     # TODO: batch the requests once the endpoint is working
 
     with progress_spinner() as progress:
         progress.add_task(
-            description=f"Loading your prompts from {prompt_file}.",
-            total=None
+            description=f"Loading your prompts from {prompt_file}.", total=None
         )
-        with open(prompt_file, 'r') as f:
+        with open(prompt_file, "r") as f:
             prompts = f.read().split(separator)
         results = {}
 
        for prompt in prompts:
             progress.add_task(
                 description=f"Processing all models against prompt: {prompt}.",
-                total=None
+                total=None,
             )
             results[prompt] = []
             for m in model:
                 result = sdk.query(m, prompt)
-                text = result['generated_text']
-                del result['generated_text']
-                results[prompt].append({
-                    'model': m,
-                    'result': text,
-                    'stats': result
-                })
+                text = result["generated_text"]
+                del result["generated_text"]
+                results[prompt].append({"model": m, "result": text, "stats": result})
 
-        progress.add_task(
-            description=f"Writing output file.",
-            total=None
-        )
-        with open(output_file, 'w') as f:
+        progress.add_task(description="Writing output file.", total=None)
+        with open(output_file, "w") as f:
             f.write(json.dumps(results, indent=2))
 
 
-evaluator_type = typer.Option(help='Which LLM to use for evaluation')
+evaluator_type = typer.Option(help="Which LLM to use for evaluation")
 
 
 @app.command()
-def evaluate(input_file: Annotated[str, results_type] = 'aviary-output.json',
-             evaluation_file: Annotated[str, results_type] = 'evaluation-output.json',
-             evaluator: Annotated[str, evaluator_type] = 'gpt-4'):
+def evaluate(
+    input_file: Annotated[str, results_type] = "aviary-output.json",
+    evaluation_file: Annotated[str, results_type] = "evaluation-output.json",
+    evaluator: Annotated[str, evaluator_type] = "gpt-4",
+):
     """Evaluate and summarize the results of a multi_query run with a strong
     'evaluator' LLM like GPT-4.
     The results of the ranking are stored to file and displayed in a table.
     """
     with progress_spinner() as progress:
-        progress.add_task(
-            description=f"Loading the evaluator LLM.",
-            total=None
-        )
-        if evaluator == 'gpt-4':
+        progress.add_task(description="Loading the evaluator LLM.", total=None)
+        if evaluator == "gpt-4":
             from aviary.common.evaluation import GPT
+
             eval_model = GPT()
         else:
-            raise NotImplementedError(f'No evaluator for {evaluator}')
+            raise NotImplementedError(f"No evaluator for {evaluator}")
 
-        with open(input_file, 'r') as f:
+        with open(input_file, "r") as f:
             results = json.load(f)
 
         for prompt, result_list in results.items():
             progress.add_task(
-                description=f"Evaluating results for prompt: {prompt}.",
-                total=None
+                description=f"Evaluating results for prompt: {prompt}.", total=None
             )
             evaluation = eval_model.evaluate_results(prompt, result_list)
             try:
                 # GPT-4 returns a string with a Python dictionary, hopefully!
                 evaluation = ast.literal_eval(evaluation)
-            except:
-                print(f'Could not parse evaluation: {evaluation}')
+            except Exception:
+                print(f"Could not parse evaluation: {evaluation}")
 
-            for i, res in enumerate(results[prompt]):
-                results[prompt][i]["rank"] = evaluation[i]['rank']
+            for i, _res in enumerate(results[prompt]):
+                results[prompt][i]["rank"] = evaluation[i]["rank"]
 
-        progress.add_task(
-            description=f"Storing evaluations.",
-            total=None
-        )
-        with open(evaluation_file, 'w') as f:
+        progress.add_task(description="Storing evaluations.", total=None)
+        with open(evaluation_file, "w") as f:
             f.write(json.dumps(results, indent=2))
 
         for prompt in results.keys():
@@ -192,7 +178,7 @@ def evaluate(input_file: Annotated[str, results_type] = 'aviary-output.json',
             table.add_column("Rank", style="magenta")
             table.add_column("Response", justify="right", style="green")
 
-            for i, res in enumerate(results[prompt]):
+            for i, _res in enumerate(results[prompt]):
                 model = results[prompt][i]["model"]
                 response = results[prompt][i]["result"]
                 rank = results[prompt][i]["rank"]
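Among the fixes in this file is replacing a bare except with except Exception, which ruff flags as E722. A minimal standalone sketch of why that matters, written for illustration and not taken from the repository:

# Illustration only (not from the commit): a bare "except:" also swallows
# KeyboardInterrupt and SystemExit, while "except Exception:" lets them propagate.
def parse_or_none(text: str):
    try:
        return int(text)
    except Exception:  # preferred over a bare "except:"
        return None


print(parse_or_none("42"), parse_or_none("not a number"))  # 42 None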
4 changes: 2 additions & 2 deletions aviary/api/env.py
@@ -1,6 +1,6 @@
 def has_ray():
     try:
-        import ray
+        import ray  # noqa: F401
 
         return True
     except ImportError:
@@ -9,7 +9,7 @@ def has_ray():
 
 def has_backend():
     try:
-        import aviary.backend
+        import aviary.backend  # noqa: F401
 
         return True
     except ImportError:
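The added # noqa: F401 comments tell ruff that these imports are intentionally unused: the functions only probe whether a package can be imported. A small self-contained sketch of the same pattern (the package name here is arbitrary and not from the repo):

# Illustration only: availability probe where the import itself is the point,
# so ruff's unused-import rule (F401) is silenced on that line.
def has_numpy() -> bool:
    try:
        import numpy  # noqa: F401
    except ImportError:
        return False
    return True


print(has_numpy())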
2 changes: 2 additions & 0 deletions aviary/backend/llm/initializers/_llama_impl.py
@@ -9,6 +9,8 @@
 
 logger = get_logger(__name__)
 
+# ruff: noqa: B006, B905
+
 
 # TODO Upstream this
 # Llama with added min_tokens parameter
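The file-wide # ruff: noqa: B006, B905 comment disables two bugbear rules for this vendored Llama code: B006 (mutable default argument) and B905 (zip without an explicit strict argument). For context, a small illustration of the B006 pitfall, not taken from the repository:

# Illustration only: the mutable default is created once and shared across calls,
# which is what B006 warns about.
def append_flagged(item, items=[]):  # B006 would flag this default
    items.append(item)
    return items


def append_safe(item, items=None):
    if items is None:
        items = []
    items.append(item)
    return items


print(append_flagged(1), append_flagged(2))  # [1, 2] [1, 2] - shared list
print(append_safe(1), append_safe(2))        # [1] [2]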
3 changes: 2 additions & 1 deletion aviary/backend/llm/pipelines/_base.py
@@ -324,7 +324,8 @@ def _sanitize_parameters(
         if len(stop_sequence_ids) > 1:
             warnings.warn(
                 "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                " the stop sequence will be used as the stop sequence string in the interim."
+                " the stop sequence will be used as the stop sequence string in the interim.",
+                stacklevel=2,
             )
             generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
 
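The added stacklevel=2 attributes the warning to the code that called _sanitize_parameters rather than to the warnings.warn call itself, presumably to satisfy bugbear's explicit-stacklevel check (B028). A minimal standalone illustration, not from the repository:

# Illustration only: with the default stacklevel=1 the warning is attributed to
# the helper itself; stacklevel=2 attributes it to the caller's line instead.
import warnings


def helper():
    warnings.warn("helper() is deprecated", DeprecationWarning, stacklevel=2)


def caller():
    helper()  # the reported filename and line now point here


warnings.simplefilter("always")
caller()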
2 changes: 1 addition & 1 deletion aviary/backend/llm/pipelines/default_pipeline.py
@@ -97,7 +97,7 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
         decoded: List[Response] = []
         num_generated_tokens_batch = 0
         num_input_tokens_batch = 0
-        for token_unwrapped, inputs_unwrapped in zip(tokens, input_ids):
+        for token_unwrapped, inputs_unwrapped in zip((tokens, input_ids), strict=True):
             logger.info(
                 f"Unprocessed generated tokens: '{self.tokenizer.decode(token_unwrapped, skip_special_tokens=False).encode('unicode_escape').decode('utf-8')}'"
             )
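The strict=True argument (new in Python 3.10) makes zip raise instead of silently truncating when its iterables have different lengths, which is what ruff's B905 rule nudges toward. Worth noting when reading the hunk above: zip((a, b), strict=True), with the two sequences wrapped in a single tuple, iterates over that one tuple rather than pairing a with b element-wise. A standalone illustration of the intended behaviour, not from the repository:

# Illustration only (Python 3.10+): strict=True turns silent truncation into an error.
tokens = ["a", "b", "c"]
input_ids = [1, 2]

print(list(zip(tokens, input_ids)))  # [('a', 1), ('b', 2)] - the extra token is dropped

try:
    list(zip(tokens, input_ids, strict=True))
except ValueError as exc:
    print(exc)  # e.g. "zip() argument 2 is shorter than argument 1"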
8 changes: 5 additions & 3 deletions aviary/backend/llm/predictor.py
@@ -99,7 +99,7 @@ def init_model(
                 assert len(resp2) == 1
                 assert all(x.generated_text for x in resp2)
                 warmup_success = True
-            except torch.cuda.OutOfMemoryError as e:
+            except torch.cuda.OutOfMemoryError:
                 batch_size -= 2
                 logger.warning(
                     f"Warmup failed due to CUDA OOM, reducing batch size to {batch_size}"
@@ -213,7 +213,7 @@ def init_worker_group(self, scaling_config: ScalingConfig) -> None:
         self._initializing = True
         try:
             self._init_worker_group(scaling_config)
-        except Exception as e:
+        except Exception:
             self._initializing = False
             raise
         self._initializing = False
@@ -271,7 +271,9 @@ def _init_worker_group(self, scaling_config: ScalingConfig) -> None:
                     local_rank,
                     num_cpus_per_worker=scaling_config.num_cpus_per_worker,
                 )
-                for worker, local_rank in zip(self.prediction_workers, local_ranks)
+                for worker, local_rank in zip(
+                    (self.prediction_workers, local_ranks), strict=True
+                )
             ]
         )
 
2 changes: 1 addition & 1 deletion aviary/backend/server/util.py
@@ -41,7 +41,7 @@ def _parse_path_args(path: str):
         return [LLMApp.parse_yaml(f)]
     elif os.path.isdir(path):
         apps = []
-        for root, dirs, files in os.walk(path):
+        for root, _dirs, files in os.walk(path):
             for p in files:
                 if _is_yaml_file(p):
                     with open(os.path.join(root, p), "r") as f:
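Renaming dirs to _dirs marks the loop variable as intentionally unused, which satisfies ruff's unused-loop-variable check (B007) without changing behaviour. A tiny standalone sketch of the same convention, not from the repository:

# Illustration only: the leading underscore signals that the directory list from
# os.walk is deliberately ignored; only the file names are used.
import os

for root, _dirs, files in os.walk("."):
    for name in files:
        print(os.path.join(root, name))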