-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
89 lines (70 loc) · 4.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import random
import torch
import os
import json
import time
def replicability(seed=None):
    """Seed all random number generators and force deterministic PyTorch kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device). It then configures cuDNN and PyTorch to use only
    deterministic algorithms.

    Args:
        seed: Integer seed. ``np.random.seed`` only accepts values in
            ``[0, 2**32)``. If ``None``, a fresh seed is drawn from the OS
            entropy source; the returned value lets the run be reproduced.

    Returns:
        The seed that was actually applied.
    """
    if seed is None:
        # torch.manual_seed(None) raises TypeError, so pick a concrete seed
        # that the caller can log and pass back in later.
        seed = random.SystemRandom().randrange(2**32)

    # cuBLAS reads this variable when its handle is created, so set it before
    # any CUDA matmul can run; deterministic mode requires it on CUDA >= 10.2.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # Seeds the default generator on all devices.
    # Explicitly seed every GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    # Seeding alone is not enough for determinism: cuDNN autotuning
    # (benchmark mode) and nondeterministic kernels must also be disabled.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    return seed
def rounder(num, places=2):
    """Convert a fraction to a percentage rounded to ``places`` decimals.

    The input is multiplied by 100 before rounding, so ``0.875`` becomes
    ``87.5``. Rounding uses the built-in ``round``.
    """
    percentage = num * 100
    return round(percentage, places)
def data_split(dataset):
    """Return the fixed 5-fold query split for a TREC CAsT dataset.

    Args:
        dataset: ``"cast-19"`` or ``"cast-20"``.

    Returns:
        A new dict mapping fold id (``1`` to ``5``) to the list of query ids,
        such as ``'31_1'``, in that fold. The dict is rebuilt on every call,
        so callers may mutate it.

    Raises:
        ValueError: If ``dataset`` is not a supported name. This subclasses
            ``Exception``, so handlers written for the old generic exception
            still work.
    """
    if dataset == "cast-19":
        # We split cast-19 following
        # https://github.com/thunlp/ConvDR/blob/main/data/gen_ranking_data.py
        # NOTE(review): some turns (e.g. 75_7) are absent from the split,
        # presumably turns without relevance judgments; copied as-is.
        foldid2qid = {
            1: [
                '31_1', '31_2', '31_3', '31_4', '31_5', '31_6', '31_7', '31_8', '31_9',
                '32_1', '32_10', '32_11', '32_2', '32_3', '32_4', '32_5', '32_6',
                '32_7', '32_8', '32_9',
                '33_1', '33_2', '33_3', '33_4', '33_5', '33_6', '33_7', '33_8',
                '34_1', '34_2', '34_3', '34_4', '34_5', '34_6', '34_7', '34_8',
            ],
            2: [
                '37_1', '37_2', '37_3', '37_4', '37_5', '37_6', '37_7', '37_8',
                '40_1', '40_2', '40_3', '40_4', '40_5', '40_6', '40_7', '40_8',
                '49_1', '49_2', '49_3', '49_4', '49_5', '49_6', '49_7', '49_8',
                '50_1', '50_2', '50_3', '50_4', '50_5', '50_6', '50_7', '50_8',
            ],
            3: [
                '54_1', '54_2', '54_3', '54_4', '54_5', '54_6', '54_7', '54_8', '54_9',
                '56_1', '56_2', '56_3', '56_4', '56_5', '56_6', '56_7', '56_8',
                '58_1', '58_2', '58_3', '58_4', '58_5', '58_6', '58_7', '58_8',
                '59_1', '59_2', '59_3', '59_4', '59_5', '59_6', '59_7', '59_8',
            ],
            4: [
                '67_1', '67_10', '67_11', '67_2', '67_3', '67_4', '67_5', '67_6',
                '67_7', '67_8', '67_9',
                '68_1', '68_10', '68_11', '68_2', '68_3', '68_4', '68_5', '68_6',
                '68_7', '68_8', '68_9',
                '69_1', '69_10', '69_2', '69_3', '69_4', '69_5', '69_6', '69_7',
                '69_8', '69_9',
            ],
            5: [
                '61_1', '61_2', '61_3', '61_4', '61_5', '61_6', '61_7', '61_8',
                '75_1', '75_2', '75_3', '75_4', '75_5', '75_6', '75_8',
                '77_1', '77_2', '77_3', '77_4', '77_5', '77_6', '77_7', '77_8',
                '78_1', '78_2', '78_3', '78_4', '78_5', '78_6', '78_7', '78_8',
                '79_1', '79_2', '79_3', '79_4', '79_5', '79_6', '79_7', '79_8', '79_9',
            ],
        }
    elif dataset == "cast-20":
        # We split cast-20 according to
        # https://github.com/thunlp/ConvDR/blob/main/data/preprocess_cast20.py
        # NOTE(review): gaps such as 87_6 or 96_2 are inherited from that
        # split; copied as-is.
        foldid2qid = {
            1: [
                '81_1', '81_2', '81_3', '81_4', '81_5', '81_6', '81_7', '81_8',
                '82_1', '82_10', '82_2', '82_3', '82_4', '82_5', '82_6', '82_7',
                '82_8', '82_9',
                '83_1', '83_2', '83_3', '83_4', '83_5', '83_6', '83_7', '83_8',
                '84_1', '84_2', '84_3', '84_4', '84_5', '84_6',
                '85_1', '85_2', '85_3', '85_4', '85_5', '85_6', '85_7', '85_8', '85_9',
            ],
            2: [
                '86_1', '86_2', '86_3', '86_4', '86_5', '86_6', '86_7',
                '87_1', '87_2', '87_3', '87_4', '87_5', '87_7', '87_8', '87_9',
                '88_1', '88_10', '88_2', '88_3', '88_4', '88_5', '88_6', '88_7',
                '88_8', '88_9',
                '89_1', '89_10', '89_11', '89_2', '89_3', '89_4', '89_5', '89_6',
                '89_7', '89_8', '89_9',
                '90_1', '90_2', '90_3', '90_4', '90_5', '90_6', '90_7', '90_8',
            ],
            3: [
                '91_1', '91_2', '91_3', '91_4', '91_5', '91_6', '91_7', '91_8',
                '92_1', '92_2', '92_3', '92_4', '92_5', '92_6', '92_7',
                '93_1', '93_2', '93_3', '93_4', '93_5', '93_6',
                '94_1', '94_2', '94_3', '94_4', '94_5', '94_6', '94_7', '94_8',
                '95_1', '95_2', '95_3', '95_4', '95_5', '95_6', '95_7', '95_8',
            ],
            4: [
                '100_1', '100_2', '100_3', '100_4', '100_5', '100_6', '100_7', '100_8',
                '96_1', '96_3', '96_4', '96_5', '96_6', '96_7', '96_8',
                '97_1', '97_2', '97_3', '97_4', '97_5', '97_6', '97_7', '97_8',
                '98_1', '98_2', '98_3', '98_4', '98_5', '98_6', '98_7', '98_8',
                '99_1', '99_2', '99_3', '99_4', '99_5', '99_6', '99_7', '99_8',
            ],
            5: [
                '101_1', '101_10', '101_2', '101_3', '101_4', '101_5', '101_6',
                '101_7', '101_8', '101_9',
                '102_1', '102_2', '102_3', '102_4', '102_5', '102_6', '102_7',
                '102_8', '102_9',
                '103_1', '103_10', '103_2', '103_3', '103_4', '103_5', '103_6',
                '103_8', '103_9',
                '104_1', '104_10', '104_12', '104_13', '104_3', '104_4', '104_6',
                '104_7', '104_8', '104_9',
                '105_1', '105_2', '105_3', '105_4', '105_5', '105_6', '105_7',
                '105_8', '105_9',
            ],
        }
    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected 'cast-19' or 'cast-20'"
        )
    return foldid2qid