-
Notifications
You must be signed in to change notification settings - Fork 4
/
util.py
52 lines (45 loc) · 1.5 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
import random
import torch
import time
def list2tuple(l):
    """Return a tuple copy of ``l`` with nested lists turned into tuples.

    Only elements whose exact type is ``list`` are converted (recursively);
    every other element, including tuples, is kept as-is.
    """
    converted = []
    for element in l:
        # Exact type match on purpose: list subclasses are left untouched.
        if type(element) == list:
            element = list2tuple(element)
        converted.append(element)
    return tuple(converted)
def tuple2list(t):
    """Return a list copy of ``t`` with nested tuples turned into lists.

    Only elements whose exact type is ``tuple`` are converted (recursively);
    every other element, including lists, is kept as-is.
    """
    converted = []
    for element in t:
        # Exact type match on purpose: tuple subclasses are left untouched.
        if type(element) == tuple:
            element = tuple2list(element)
        converted.append(element)
    return converted
def flatten(l):
    """Flatten arbitrarily nested tuples into a single flat list.

    Args:
        l: Any object. Tuples are expanded recursively; every other value
            (including lists) is treated as a leaf.

    Returns:
        A new list holding the leaves in left-to-right order. A non-tuple
        argument yields a one-element list; an empty tuple yields ``[]``.
    """
    if not isinstance(l, tuple):
        return [l]
    leaves = []
    for element in l:
        # extend() grows the result in place, avoiding the repeated list
        # copies that concatenating with sum(..., []) would perform.
        leaves.extend(flatten(element))
    return leaves
def parse_time():
    """Return the current local time formatted as ``YYYY.MM.DD-HH:MM:SS``."""
    timestamp_format = "%Y.%m.%d-%H:%M:%S"
    now = time.localtime()
    return time.strftime(timestamp_format, now)
def set_global_seed(seed):
    """Seed the torch, torch.cuda, NumPy and ``random`` generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True
def eval_tuple(arg_return):
    """Evaluate a tuple string into a tuple.

    Args:
        arg_return: A tuple, which is returned unchanged, or a string.
            A string starting with "(" or "[" has its first and last
            characters removed and is split on commas; each piece is
            evaluated on its own. Pieces that fail to evaluate are kept as
            raw text, and pieces equal to the empty string are dropped.
            Any other string is passed to ``eval`` as a whole.

    Returns:
        The input tuple, a tuple built from the comma-separated pieces, or
        whatever ``eval`` produced for a non-bracketed string.

    Note:
        The split is purely textual, so nested brackets are not parsed as
        nested tuples.
    """
    if isinstance(arg_return, tuple):
        return arg_return
    if arg_return[0] not in ("(", "["):
        # SECURITY: eval() executes arbitrary code; only pass trusted strings.
        return eval(arg_return)
    items = []
    for piece in arg_return[1:-1].split(","):
        try:
            # SECURITY: same caveat as above, applied per piece.
            value = eval(piece)
        except Exception:
            # Not a valid expression (e.g. a bare word): keep the raw text.
            # Catching Exception rather than everything lets SystemExit and
            # KeyboardInterrupt propagate instead of being swallowed.
            value = piece
        if value == "":
            # Produced by "()" or a trailing comma; nothing to add.
            continue
        items.append(value)
    return tuple(items)
def flatten_query(queries):
    """Collect the queries of every structure type into a single list.

    Each query is paired with the structure key it was stored under, so the
    result is a list of ``(query, query_structure)`` tuples, in the iteration
    order of ``queries``.
    """
    return [
        (query, structure)
        for structure in queries
        for query in list(queries[structure])
    ]