gen_summary_metadata.py (forked from pytorch/benchmark)

"""
A Benchmark Summary Metadata tool to extract and generate metadata from models at runtime.
"""
import argparse
import os
from copy import deepcopy
from typing import Any, Dict, List, Tuple
import torch
import yaml
from torchbenchmark import (
    _list_model_paths,
    list_models,
    load_model_by_name,
    ModelDetails,
    ModelTask,
    str_to_bool,
)

TIMEOUT = 300  # seconds
torchbench_dir = "torchbenchmark"
model_dir = "models"

_DEFAULT_METADATA_ = {
    "train_benchmark": True,
    "train_deterministic": False,
    "eval_benchmark": True,
    "eval_deterministic": False,
    "eval_nograd": True,
    # 'origin': None,
    # 'train_dtype': 'float32',
    # 'eval_dtype': 'float32',
}


def _parser_helper(value):
    # Tri-state boolean: None means "flag not set"; any other value is
    # parsed as a boolean string.
    return None if value is None else str_to_bool(str(value))
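# Illustrative behavior (assuming torchbenchmark's str_to_bool accepts the
# usual "true"/"false"-style strings):
#   _parser_helper(None)    -> None   (flag absent; keep the extracted value)
#   _parser_helper("true")  -> True
#   _parser_helper("false") -> False
# Keeping None distinct from False is what lets
# _maybe_override_extracted_details below tell "not set" from "set to False".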


def _process_model_details_to_metadata(
    train_detail: ModelDetails, eval_detail: ModelDetails
) -> Dict[str, Any]:
    # Fallback order per key: attribute on the train detail, then the train
    # metadata dict, then the eval metadata dict, then the hard-coded default.
    metadata = {}
    for k, v in _DEFAULT_METADATA_.items():
        if hasattr(train_detail, k):
            metadata[k] = getattr(train_detail, k)
        elif train_detail and k in train_detail.metadata:
            metadata[k] = train_detail.metadata[k]
        elif eval_detail and k in eval_detail.metadata:
            metadata[k] = eval_detail.metadata[k]
        else:
            metadata[k] = v
    return metadata
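# A worked example of the fallback order above, assuming ModelDetails does not
# expose these keys as attributes (hypothetical values, for illustration):
#   train_detail.metadata == {"train_benchmark": False}
#   eval_detail.metadata  == {"eval_nograd": False}
# would merge to:
#   {"train_benchmark": False, "train_deterministic": False,
#    "eval_benchmark": True, "eval_deterministic": False, "eval_nograd": False}
# with every unset key falling back to _DEFAULT_METADATA_.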


def _extract_detail(path: str) -> Dict[str, Any]:
    name = os.path.basename(path)
    device = "cuda"
    t_detail = None
    e_detail = None
    # Run train and eval in separate, isolated processes so a failure in one
    # mode cannot affect the other.
    task_t = ModelTask(name, timeout=TIMEOUT)
    try:
        task_t.make_model_instance(device=device)
        task_t.set_train()
        task_t.train()
        task_t.extract_details_train()
        task_t.del_model_instance()
        t_detail = deepcopy(task_t._details)
    except NotImplementedError:
        print(f"Model {name} train is not fully implemented. Skipping...")
    del task_t
    task_e = ModelTask(name, timeout=TIMEOUT)
    try:
        task_e.make_model_instance(device=device)
        task_e.set_eval()
        task_e.eval()
        task_e.extract_details_eval()
        task_e.del_model_instance()
        e_detail = deepcopy(task_e._details)
    except NotImplementedError:
        print(f"Model {name} eval is not fully implemented. Skipping...")
    del task_e
    return _process_model_details_to_metadata(t_detail, e_detail)


def _extract_all_details(model_names: List[str]) -> List[Tuple[str, Dict[str, Any]]]:
    details = []
    for model_path in _list_model_paths():
        model_name = os.path.basename(model_path)
        if model_name not in model_names:
            continue
        ed = _extract_detail(model_path)
        details.append((model_path, ed))
    return details


def _print_extracted_details(extracted_details: List[Tuple[str, Dict[str, Any]]]):
    for path, ex_detail in extracted_details:
        name = os.path.basename(path)
        print(f"Model: {name}, Details: {ex_detail}")


def _maybe_override_extracted_details(
    args, extracted_details: List[Tuple[str, Dict[str, Any]]]
):
    for _path, ex_detail in extracted_details:
        # Apply each override independently; an elif chain here would let
        # only the first flag that was set take effect.
        if args.train_benchmark is not None:
            ex_detail["train_benchmark"] = args.train_benchmark
        if args.train_deterministic is not None:
            ex_detail["train_deterministic"] = args.train_deterministic
        if args.eval_benchmark is not None:
            ex_detail["eval_benchmark"] = args.eval_benchmark
        if args.eval_deterministic is not None:
            ex_detail["eval_deterministic"] = args.eval_deterministic
        if args.eval_nograd is not None:
            ex_detail["eval_nograd"] = args.eval_nograd


def _write_metadata_yaml_files(extracted_details: List[Tuple[str, Dict[str, Any]]]):
    for path, ex_detail in extracted_details:
        metadata_path = os.path.join(path, "metadata.yaml")
        with open(metadata_path, "w") as file:
            yaml.dump(ex_detail, file)
        print(f"Processed file: {metadata_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model",
        default=None,
        help="Full name of a model to update. If absent, applies to all models.",
    )
    parser.add_argument(
        "--extract-only",
        default=False,
        action="store_true",
        help="Only extract model details.",
    )
    parser.add_argument(
        "--train-benchmark",
        default=None,
        type=_parser_helper,
        help="Whether to enable PyTorch benchmark mode during train.",
    )
    parser.add_argument(
        "--train-deterministic",
        default=None,
        type=_parser_helper,
        help="Whether to enable deterministic mode during train.",
    )
    parser.add_argument(
        "--eval-benchmark",
        default=None,
        type=_parser_helper,
        help="Whether to enable PyTorch benchmark mode during eval.",
    )
    parser.add_argument(
        "--eval-deterministic",
        default=None,
        type=_parser_helper,
        help="Whether to enable deterministic mode during eval.",
    )
    parser.add_argument(
        "--eval-nograd",
        default=None,
        type=_parser_helper,
        help="Whether to enable no_grad during eval.",
    )
    # parser.add_argument("--origin", default=None,
    #                     help="Location of the benchmark's origin, such as torchaudio or torchvision.")
    # parser.add_argument("--train-dtype", default=None,
    #                     choices=['float32', 'float16', 'bfloat16', 'amp'], help="Which fp type to perform training.")
    # parser.add_argument("--eval-dtype", default=None,
    #                     choices=['float32', 'float16', 'bfloat16', 'amp'], help="Which fp type to perform eval.")
    args = parser.parse_args()
    # This script currently requires a CUDA device.
    if not torch.cuda.is_available():
        print(
            "This tool is currently only supported when the system has a cuda device."
        )
        exit(1)
    # Find the matching model, or use all models.
    models = []
    model_names = []
    if args.model is not None:
        Model = load_model_by_name(args.model)
        if not Model:
            print(f"Unable to find a model matching: {args.model}.")
            exit(-1)
        models.append(Model)
        model_names.append(Model.name)
        print(f"Generating metadata for the selected model: {model_names}.")
    else:
        models.extend(list_models(model_match=args.model))
        model_names.extend([m.name for m in models])
        print("Generating metadata for all models.")
    # Extract all model details from the selected models.
    extracted_details = _extract_all_details(model_names)
    print("Printing extracted metadata.")
    _print_extracted_details(extracted_details)
    # Stop here for extract-only.
    if args.extract_only:
        print("--extract-only is set. Stopping here.")
        exit(0)
    # Apply any overrides passed in by flags.
    _maybe_override_extracted_details(args, extracted_details)
    print("Printing metadata after applying any modifications.")
    _print_extracted_details(extracted_details)
    # TODO: Modify and update the model to apply metadata changes by the user.
    # Generate a metadata file for each matching model.
    _write_metadata_yaml_files(extracted_details)