-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
64 lines (56 loc) · 2.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import dotenv
dotenv.load_dotenv()
import argparse
from arena.arena_builder import ArenaBuilder
from contests.templates_factory import get_all_templates
from arena.job_queue import DuelsQueue
from arena.result_reporter import ChartReporter
from arena.storage import Storage
from utils.logger import Logger
class RivaLLMatch:
    """Command-line entry point that runs an LLM tournament ("arena").

    Parses experiment arguments, wires up per-experiment logging and SQLite
    storage, builds the arena of competing models, runs the duel rounds and
    finally renders chart reports from the competition scores.
    """

    # Default roster of competing models, used when --models is not supplied
    # on the command line.
    model_names = [
        'gpt-4o-2024-08-06',
        'gpt-4o-mini',
        'gpt-3.5-turbo-0125',
        'claude-3-opus-20240229',
        'claude-3-5-sonnet-20240620',
        'llama3-8b-8192',
        'llama3-70b-8192',
        # 'gemma-7b-it',
        'mixtral-8x7b-32768',
        'open-mixtral-8x22b-2404',
        'gemini-1.0-pro-latest',
        'gemini-1.5-pro-latest',
    ]

    def __init__(self):
        """Parse CLI arguments and initialise logging, storage and the duel queue.

        Side effects: attaches a per-experiment file logger and opens (or
        creates) a SQLite database named after the experiment ID in the
        current working directory.
        """
        parser = argparse.ArgumentParser(description="Parse command line arguments for an experiment.")
        parser.add_argument('--rounds', type=int, required=False,
                            help='Number of duel rounds between models (default 4).')
        parser.add_argument('--experiment_id', type=str, required=True,
                            help="Experiment ID (required).")
        parser.add_argument('--template_id', type=str,
                            help=f"""Experiment prompt's template ID (required for new experiment).
                            One of values: {[t.get_template_id() for t in get_all_templates()]}""")
        parser.add_argument('--models', type=lambda s: s.split(','),
                            help="List of comma separated LLM model names (required for new experiment).")
        self.args = parser.parse_args()
        # Mirror all log output into a file dedicated to this experiment.
        Logger.logger.append_file_logger(f"{self.args.experiment_id}.log")
        self.db_path = f"./{self.args.experiment_id}.db"
        self.storage = Storage(db_path=self.db_path)
        self.duels_queue = DuelsQueue(db_path=self.db_path)

    def run(self):
        """Build the arena, run all duels and generate the report charts."""
        # Fix: honor the --models argument when provided. Previously the
        # parsed value was ignored and the hard-coded class roster was
        # always used; falling back to it keeps old invocations working.
        model_names = self.args.models or RivaLLMatch.model_names
        n_llms = len(model_names)
        n_rounds = self.args.rounds or 4  # default to 4 rounds when omitted
        arena = (ArenaBuilder(n_rounds,
                              model_names,
                              self.args.template_id,
                              self.storage,
                              self.duels_queue)
                 .create())
        # One worker per model so every LLM can duel concurrently.
        competition_scores = arena.run(n_jobs=n_llms)
        reporter = ChartReporter(self.args.template_id, model_names, competition_scores)
        reporter.generate_reports()
# Script entry point: construct the tournament runner and execute it.
if __name__ == '__main__':
    RivaLLMatch().run()