-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_jsons.py
66 lines (46 loc) · 2.05 KB
/
create_jsons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pandas as pd
import os
import json
import numpy as np
import sys
from sklearn.model_selection import train_test_split
def convert_MMAR(df, possibleLabelCount):
    """Convert rows of image paths and integer labels into MMAR datalist entries.

    Each valid row becomes ``{"image": <file name>, "label": <one-hot list>}``.
    The file name is the last ``/``-separated component of the ``Path`` column.
    The one-hot list has length ``possibleLabelCount`` with a 1 at index
    ``int(Label)``.

    Rows whose label is not in ``[0, possibleLabelCount)`` (including NaN) are
    reported on stdout and skipped; they produce no entry.

    Args:
        df: DataFrame with ``Path`` (string) and ``Label`` (numeric) columns.
        possibleLabelCount: Number of classes, i.e. the one-hot vector length.

    Returns:
        List of entry dicts, one per valid row, in dataframe order.
    """
    ret_list = []
    for path, label in zip(df["Path"], df["Label"]):
        image = path.split("/")[-1]
        # Reject negatives as well: a negative index would silently set a
        # position counted from the end of the one-hot vector.
        if not 0 <= label < possibleLabelCount:
            print("Skipping Image: Invalid Label")
            continue  # previously fell through and appended a stale/undefined entry
        one_hot = [0] * possibleLabelCount
        one_hot[int(label)] = 1
        ret_list.append({"image": image, "label": one_hot})
    return ret_list
def convert_to_json(df, possibleLabelCount, save_csv=True, random_state=2021):
    """Split ``df`` into train/valid/test sets and write MMAR datalist JSON files.

    25% of the rows are held out from training. That hold-out is then split in
    half into validation and test sets, giving roughly a 75/12.5/12.5 split.
    Both splits use ``random_state``, so results are reproducible.

    Files written to the current working directory:
        * ``train.csv``, ``valid.csv``, ``test.csv`` (only when ``save_csv``)
        * ``train.json``: ``training`` and ``validation`` entry lists
        * ``test.json``: the test entries under the ``validation`` key

    Args:
        df: DataFrame with ``Path`` and ``Label`` columns (see ``convert_MMAR``).
        possibleLabelCount: Number of classes / one-hot label length.
        save_csv: Also dump the raw split dataframes as CSV files.
        random_state: Seed passed to both ``train_test_split`` calls.
    """
    train, valid = train_test_split(df, test_size=0.25, random_state=random_state)
    valid, test = train_test_split(valid, test_size=0.5, random_state=random_state)

    if save_csv:
        train.to_csv("train.csv")
        valid.to_csv("valid.csv")
        test.to_csv("test.csv")

    mmar_train = convert_MMAR(train, possibleLabelCount)
    mmar_valid = convert_MMAR(valid, possibleLabelCount)
    mmar_test = convert_MMAR(test, possibleLabelCount)

    label_format = [1] * possibleLabelCount
    MMAR_train_valid = {
        "label_format": label_format,
        "training": mmar_train,
        "validation": mmar_valid,
    }
    # NOTE(review): the test set is stored under "validation", presumably so an
    # evaluation step that reads the validation list can consume it — confirm.
    MMAR_test = {"label_format": list(label_format), "validation": mmar_test}

    with open("train.json", "w", encoding="utf-8") as f:
        json.dump(MMAR_train_valid, f, indent=4)
    with open("test.json", "w", encoding="utf-8") as f:
        json.dump(MMAR_test, f, indent=4)
if __name__ == "__main__":
    # Usage: python create_jsons.py <data.csv>
    # The CSV path is required. A single extra argument is still tolerated
    # (and ignored) so that previously accepted command lines keep working.
    if not 2 <= len(sys.argv) <= 3:
        raise ValueError("Malformed command line. Please specify the path to your data csv")
    possibleLabelCount = 6
    csv_path = sys.argv[1]
    print('Converting CSV to JSON')
    df = pd.read_csv(csv_path)
    convert_to_json(df, possibleLabelCount)