-
Notifications
You must be signed in to change notification settings - Fork 0
/
Batching.py
59 lines (50 loc) · 1.56 KB
/
Batching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math
# Fixed random seed, passed as random_state to the per-bin sampling below.
seed = 0
print()

# read csv
results = pd.read_csv('GoogleSearch.csv')
print('total number of results: %d'%(len(results)))

# Keep only cues with fewer than 3 spaces, i.e. at most 3 words.
# .copy() detaches the filtered rows from the frame read from disk, so that
# adding the 'log' column below writes to an independent DataFrame instead of
# a slice (which triggers pandas' SettingWithCopyWarning).
results = results[results['cues'].str.count(' ')<3].copy()
print('number of cues (with results) of length less than 4 words: %d'%(len(results)))

# log (base 10) bin results: one unit-wide bin per power of ten, from 0 up to
# the ceiling of the largest log value.
results['log'] = np.log10(results['result']+results['compare'])
bin_max = math.ceil(results['log'].max())
hist, bins = np.histogram(results['log'], bins=range(bin_max+1))

# plot histogram (log-scaled counts) using the same bin edges, save to disk
ax = results['log'].plot.hist(log=True, bins=bins, edgecolor='black')
ax.set_xlabel('log (base-10) of number of Google Search results')
ax.set_ylabel('number of cues per bin')
plt.savefig('histogram.png')
# get batch draft: draw up to `get` cues at random from every log10 bin
get = 20
batch = []
print()
last_edge = bins[-1]
for b in bins[:-1]:
    # Use the same bin membership as np.histogram: [b, b+1) for every bin
    # except the last one, which also includes its right edge. Strict
    # comparisons on both sides would silently drop every cue whose log
    # value falls exactly on an integer edge.
    at_or_above = results['log'] >= b
    if b + 1 == last_edge:
        below = results['log'] <= b + 1
    else:
        below = results['log'] < b + 1
    _bin = results['cues'][at_or_above & below]
    # A bin holding fewer than `get` cues contributes all of them.
    _get = min(get, len(_bin))
    print('number in bin 10^%d to 10^%d: %d'%(b, b+1, _get))
    # Sample without replacement; the fixed seed keeps the draft reproducible.
    batch += _bin.sample(n=_get, replace=False, random_state=seed).to_list()
print()
print('total number in batch draft: %d'%(len(batch)))
## filter place names (currently disabled)
# NOTE: if this is re-enabled, batch.remove() inside `for cue in batch`
# mutates the list while it is being iterated, which can skip entries.
#places = pd.read_csv('STR.csv')
#for place in places['feature']:
# for cue in batch:
# if place in cue:
# batch.remove(cue)
#print(len(batch))

# Persist the draft: wrap the list of cues in a DataFrame and write it to
# batch.csv with neither an index column nor a header row.
batch = pd.DataFrame(batch)
batch.to_csv('batch.csv', header=False, index=False)
print()