-
Notifications
You must be signed in to change notification settings - Fork 2
/
augment.py
80 lines (55 loc) · 2.77 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3.6
import argparse
from functools import partial
from pathlib import Path
from typing import List, Optional

from PIL import Image
from tqdm import tqdm

from utils import augment, map_, mmap_, str2bool
def main(args: argparse.Namespace) -> None:
    """Augment every ``*.png`` found in the sub-folders of ``args.root_dir``.

    Each immediate sub-directory of ``root_dir`` is treated as one parallel
    folder: file names are listed from the first folder, and the file with the
    same name is opened in every folder and augmented together (see
    ``process_name``). Results go to ``dest_dir/<folder name>``.

    Args:
        args: Parsed CLI namespace (see ``get_args``); reads ``root_dir``,
            ``dest_dir`` and ``n_aug`` here and forwards the whole namespace.

    Raises:
        ValueError: if ``root_dir`` contains no sub-directories.
    """
    print(f'>>> Starting data augmentation (original + {args.n_aug} new images)')
    root_dir: str = args.root_dir
    dest_dir: str = args.dest_dir

    # Only directories are inputs (a stray file would otherwise be treated as a
    # folder). Sorting makes folder order - and thus the order of images passed
    # to ``augment`` - independent of filesystem listing order.
    folders: List[Path] = sorted(p for p in Path(root_dir).glob("*") if p.is_dir())
    if not folders:
        raise ValueError(f"No sub-folders found in {root_dir}")
    dest_folders: List[Path] = [Path(dest_dir, p.name) for p in folders]
    print(f"Will augment data from {len(folders)} folders ({map_(str, folders)})")

    # Create all the destination folders
    for d_folder in dest_folders:
        d_folder.mkdir(parents=True, exist_ok=True)

    # The first folder defines which file names are processed; every other
    # folder is expected to contain a file with the same name.
    names: List[str] = sorted(p.name for p in folders[0].glob("*.png"))

    # ``partial`` over a module-level function keeps the callable picklable,
    # presumably required because ``mmap_`` may map in parallel (see utils).
    partial_process = partial(process_name, folders=folders, dest_folders=dest_folders,
                              n_aug=args.n_aug, args=args)
    mmap_(partial_process, names)
def process_name(name: str, folders: List[Path], dest_folders: List[Path], n_aug: int, args) -> None:
    """Load ``name`` from every folder; save the original plus ``n_aug`` augmented copies.

    Args:
        name: File name (e.g. ``"x.png"``) expected to exist in each of ``folders``.
        folders: Source folders; the same file name is opened in each one.
        dest_folders: Output folders, paired index-by-index with ``folders``.
        n_aug: Number of augmented variants to write, numbered ``1..n_aug``
            (the unmodified images are written as ``0``).
        args: Parsed CLI namespace; its ``rotate_angle``, ``flip``, ``mirror``,
            ``rotate`` and ``scale`` fields are forwarded to ``augment``.
    """
    images: List[Image.Image] = []
    for folder in folders:
        # ``with`` closes the file handle deterministically. ``convert`` returns a
        # new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(Path(folder, name)) as src:
            images.append(src.convert('L'))

    stem: str = Path(name).stem

    # Save the unmodified images as _0
    save(stem, 0, images, dest_folders)

    # Every variant is derived from the original images, not from the previous
    # variant; each ``augment`` call presumably draws new random transforms.
    for i in range(1, n_aug + 1):
        augmented: List[Image.Image] = augment(*images,
                                               rotate_angle=args.rotate_angle,
                                               flip=args.flip,
                                               mirror=args.mirror,
                                               rotate=args.rotate,
                                               scale=args.scale)
        save(stem, i, augmented, dest_folders)
def save(stem: str, n: int, imgs: "List[Image.Image]", dest_folders: List[Path]) -> None:
    """Write ``imgs[i]`` to ``dest_folders[i] / f"{n}_{stem}.png"`` for every i.

    Args:
        stem: Base file name without its ``.png`` extension; may itself contain dots.
        n: Variant index prefixed to the file name (0 = unmodified original).
        imgs: Images to write, one per destination folder.
        dest_folders: Target folders, paired index-by-index with ``imgs``.

    Raises:
        ValueError: if ``imgs`` and ``dest_folders`` differ in length (``zip``
            would otherwise silently drop the extras).
    """
    if len(imgs) != len(dest_folders):
        raise ValueError(f"Got {len(imgs)} images for {len(dest_folders)} destination folders")

    # Append the extension directly. ``with_suffix`` would *replace* any dot-suffix
    # already in ``stem`` ("case.01" -> "0_case.png"), so distinct inputs would
    # silently overwrite each other.
    filename = f"{n}_{stem}.png"
    for img, folder in zip(imgs, dest_folders):
        img.save(Path(folder, filename))
def get_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the augmentation script.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic use.

    Returns:
        The parsed namespace, which is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description='Data augmentation parameters')
    parser.add_argument('--root_dir', type=str, required=True,
                        help="Directory whose sub-folders contain the source .png images")
    parser.add_argument('--dest_dir', type=str, required=True,
                        help="Directory where same-named output sub-folders are created")
    parser.add_argument('--n_aug', type=int, required=True,
                        help="Number of augmented copies per image (the original is always kept)")
    parser.add_argument('--flip', type=str2bool, default=True,
                        help="Forwarded to utils.augment as flip=")
    parser.add_argument('--mirror', type=str2bool, default=True,
                        help="Forwarded to utils.augment as mirror=")
    parser.add_argument('--rotate', type=str2bool, default=True,
                        help="Forwarded to utils.augment as rotate=")
    parser.add_argument('--scale', type=str2bool, default=False,
                        help="Forwarded to utils.augment as scale=")
    # Float default so ``rotate_angle`` has the same type whether or not it is
    # given on the command line (argparse does not convert non-string defaults).
    parser.add_argument('--rotate_angle', type=float, default=45.0,
                        help="Forwarded to utils.augment as rotate_angle=")
    args = parser.parse_args(argv)

    # A negative count would silently produce only the unmodified copies.
    if args.n_aug < 0:
        parser.error("--n_aug must be >= 0")

    print(args)
    return args
if __name__ == '__main__':
    # Script entry point: parse the CLI options, then run the augmentation.
    cli_args = get_args()
    main(cli_args)