-
Notifications
You must be signed in to change notification settings - Fork 0
/
cli.py
139 lines (118 loc) · 4.88 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import logging
import json
from tqdm import tqdm
from argparse import ArgumentParser
from model import YOLO
from label_studio_sdk.client import LabelStudio
from label_studio_ml.response import ModelResponse
# Connection settings, read once at import time. Each value can be
# overridden through the environment variable of the same name; the second
# argument is the fallback used when the variable is unset.
LABEL_STUDIO_URL = os.environ.get("LABEL_STUDIO_URL", "http://localhost:8080")
LABEL_STUDIO_API_KEY = os.environ.get("LABEL_STUDIO_API_KEY", "your_api_key")

# Project id, kept as a string exactly as it comes from the environment.
PROJECT_ID = os.environ.get("LABEL_STUDIO_PROJECT_ID", "1")

# Logger shared by everything in this module.
logger = logging.getLogger(__name__)
def arg_parser():
    """Parse the command-line options of the YOLO prediction client.

    The defaults of ``--ls-url``, ``--ls-api-key`` and ``--project`` come from
    the module-level constants ``LABEL_STUDIO_URL``, ``LABEL_STUDIO_API_KEY``
    and ``PROJECT_ID``, which are themselves read from the environment.

    Returns:
        The ``argparse`` namespace with the attributes ``ls_url``,
        ``ls_api_key``, ``project`` and ``tasks`` (all strings).
    """
    parser = ArgumentParser(description="YOLO client for Label Studio ML Backend")
    parser.add_argument(
        "--ls-url", type=str, default=LABEL_STUDIO_URL, help="Label Studio URL"
    )
    parser.add_argument(
        "--ls-api-key",
        type=str,
        default=LABEL_STUDIO_API_KEY,
        help="Label Studio API Key",
    )
    # Use PROJECT_ID rather than a literal "1" so LABEL_STUDIO_PROJECT_ID is
    # honoured the same way the URL and API key variables are.
    parser.add_argument(
        "--project", type=str, default=PROJECT_ID, help="Label Studio Project ID"
    )
    parser.add_argument(
        "--tasks",
        type=str,
        default="tasks.json",
        help="Path to tasks JSON file with list of ids or task datas. Example: tasks.json\n"
        "String with ids separated by comma: if you provide task ids, "
        "task data will be downloaded automatically from the Label Studio instance. Example: 1,2,3",
    )
    return parser.parse_args()
class LabelStudioMLPredictor:
    """Run the YOLO ML backend over Label Studio tasks and upload the results.

    One Label Studio SDK client is created per predictor. It is used to read
    the project, to fetch task data when only task ids are given, and to
    create the predictions.
    """

    def __init__(self, ls_url, ls_api_key):
        """Create the Label Studio SDK client.

        Args:
            ls_url: Base URL of the Label Studio instance.
            ls_api_key: API key handed to the SDK client.
        """
        self.ls = LabelStudio(base_url=ls_url, api_key=ls_api_key)
        logger.info(f"Successfully connected to Label Studio: {ls_url}")

    def run(self, project, tasks):
        """Predict every task and store the predictions in Label Studio.

        Args:
            project: Id of the Label Studio project to read the label config
                from.
            tasks: Path to a JSON file or a comma-separated string of task
                ids; see ``prepare_tasks`` for the accepted forms.

        Raises:
            ValueError: If ``tasks`` cannot be turned into a task list.
        """
        ls = self.ls
        project = ls.projects.get(id=project)
        logger.info(f"Project is retrieved: {project.id}")
        tasks = self.prepare_tasks(ls, tasks)

        # load YOLO model
        # TODO: use get_all_classes_inherited_LabelStudioMLBase to detect model classes
        model = YOLO(project_id=project.id, label_config=project.label_config)
        logger.info("YOLO ML backend is created")

        # predict one task at a time and send its predictions to Label Studio
        for task in tqdm(tasks, desc="Predict tasks"):
            response = model.predict([task])
            predictions = self.postprocess_response(model, response, task)
            # postprocess_response returns None when the model produced
            # nothing usable; skip the task instead of iterating over None.
            if not predictions:
                continue
            for prediction in predictions:
                ls.predictions.create(
                    task=task["id"],
                    score=prediction.get("score", 0),
                    model_version=prediction.get("model_version", "none"),
                    result=prediction["result"],
                )
        logger.info("Model predictions are done!")

    @staticmethod
    def postprocess_response(model, response, task):
        """Reduce a model response to its list of predictions.

        Args:
            model: Backend instance; its ``model_version`` is applied to a
                ``ModelResponse`` that does not carry a version yet.
            response: Value returned by ``model.predict``: ``None``, a list of
                prediction dicts (old format) or a ``ModelResponse``.
            task: Task the response belongs to; only used in the warning.

        Returns:
            The predictions, or ``None`` when there is no response or the
            response type is not supported.
        """
        if response is None:
            logger.warning(f"No predictions for task: {task}")
            return None

        # model returned list of dicts with predictions (old format)
        if isinstance(response, list):
            return response

        if not isinstance(response, ModelResponse):
            logger.error("No predictions generated by model")
            return None

        # model returned ModelResponse: make sure a version is set before dumping
        if not response.has_model_version():
            if model.model_version:
                response.set_version(str(model.model_version))
            else:
                response.update_predictions_version()
        return response.model_dump().get("predictions")

    @staticmethod
    def prepare_tasks(ls, tasks):
        """Turn the ``--tasks`` argument into a list of task dicts.

        Args:
            ls: Label Studio SDK client, used only when task ids are given.
            tasks: Either the path of an existing JSON file containing a list
                of task dicts or a list of task ids, or a comma-separated
                string of integer task ids such as ``"1,2,3"``.

        Returns:
            A non-empty list of dicts that each have an ``"id"`` and a
            ``"data"`` key.

        Raises:
            ValueError: If ``tasks`` is neither an existing file nor a list of
                integer ids, if the list is empty, or if its items are not
                all task dicts or all task ids.
        """
        source = tasks
        if os.path.exists(source):
            with open(source, encoding="utf-8") as f:
                tasks = json.load(f)
        else:
            # Not a file: treat it as "1,2,3". Blank tokens (for example from
            # a trailing comma) are skipped rather than failing in int("").
            try:
                tasks = [int(item) for item in source.split(",") if item.strip()]
            except ValueError as err:
                raise ValueError(
                    f"'{source}' is neither an existing tasks file nor a "
                    "comma-separated list of integer task ids"
                ) from err

        # Explicit raises instead of assert: asserts are stripped under -O.
        if not isinstance(tasks, list):
            raise ValueError("Tasks should be a list")
        if not tasks:
            raise ValueError("Task list can't be empty")
        logger.info(f"Detected {len(tasks)} tasks")

        # check task data on every item, not only the first one
        if all(isinstance(task, dict) for task in tasks):
            if any("data" not in task or "id" not in task for task in tasks):
                raise ValueError("'data' and 'id' must be presented in all tasks")
        elif all(isinstance(task, int) for task in tasks):
            # load tasks from Label Studio instance using SDK
            logger.info("Task loading from Label Studio instance ...")
            tasks = [
                {"id": task_id, "data": ls.tasks.get(task_id).data}
                for task_id in tqdm(tasks)
            ]
            logger.info("Task loading finished")
        else:
            raise ValueError(
                "Unknown task format: "
                "tasks should be a list of dicts (task data) or a list of task ids"
            )
        return tasks
if __name__ == "__main__":
    # Script entry point: read the CLI options, build the predictor with the
    # connection options, then run it for the requested project and tasks.
    cli_args = arg_parser()
    LabelStudioMLPredictor(cli_args.ls_url, cli_args.ls_api_key).run(
        cli_args.project, cli_args.tasks
    )