-
Notifications
You must be signed in to change notification settings - Fork 2
/
redis_client.py
119 lines (94 loc) · 3.54 KB
/
redis_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import io
import operator
import sys

import redis
import torch
from tqdm.auto import tqdm
ver = sys.version_info
# Pickle protocol 5 (PEP 574) ships with Python 3.8+; older interpreters
# fall back to protocol 4.
PICKLE_VERSION = 5 if ver >= (3, 8) else 4
# Single module-wide connection pool; every client created in this file
# borrows its connection from here.
CXN = redis.ConnectionPool(port=6379, host='localhost', db=0)
class RedisListObject:
    """Sequence-like view over a Redis list whose elements are ``torch.save`` payloads.

    Each list entry holds the bytes produced by ``torch.save`` for one Python
    object. Every operation opens a client on the module-level pool ``CXN``;
    constructing the view performs no Redis call.
    """

    def __init__(self, name):
        """Bind the view to the Redis key *name*."""
        self.name = name

    @staticmethod
    def _serialize(value):
        """Return the ``torch.save`` serialization of *value* as ``bytes``."""
        with io.BytesIO() as buf:
            torch.save(value, buf, pickle_protocol=PICKLE_VERSION,
                       _use_new_zipfile_serialization=True)
            # getvalue() copies the data, so no memoryview export outlives
            # the buffer (closing a BytesIO with a live export raises
            # BufferError). The bytes sent to Redis are identical either way.
            return buf.getvalue()

    def __len__(self):
        """Return the number of stored elements (0 if the key does not exist)."""
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            return rdb.llen(self.name)

    def __setitem__(self, index, value):
        """Overwrite the element at *index* (negative indices allowed) with *value*.

        Raises:
            TypeError: if *index* is not an integer.
            IndexError: if *index* is outside ``[-len, len)``; a missing key
                has length 0, so any index raises.
        """
        index = operator.index(index)
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            length = rdb.llen(self.name)
            # Validate here so out-of-range negatives raise IndexError rather
            # than a Redis "index out of range" ResponseError from LSET.
            if not -length <= index < length:
                raise IndexError(f'list index {index} out of range for {self.name}')
            rdb.lset(self.name, index, self._serialize(value))

    def __getitem__(self, index):
        """Return the deserialized element at *index* (negative indices allowed).

        Raises:
            TypeError: if *index* is not an integer (slices are not supported).
            redis.DataError: if the key does not exist.
            IndexError: if *index* is outside ``[-len, len)``.
        """
        index = operator.index(index)
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            raw = rdb.lindex(self.name, index)
            if raw is None:
                # LINDEX returns nil both for a missing key and for an
                # out-of-range index; EXISTS tells the two cases apart.
                if not rdb.exists(self.name):
                    raise redis.DataError(f'Dataset named {self.name} does not exist')
                raise IndexError(f'list index {index} out of range for {self.name}')
        # NOTE: torch.load deserializes pickled data; only read from a Redis
        # instance whose contents you trust.
        with io.BytesIO(raw) as buf:
            return torch.load(buf)

    def append(self, value):
        """Serialize *value* and push it onto the tail of the list."""
        payload = self._serialize(value)
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            # RPUSH creates the list when the key is absent, so no separate
            # existence check (or LPUSH special case) is needed.
            rdb.rpush(self.name, payload)

    def delete(self):
        """Remove the whole list from Redis.

        Raises:
            redis.DataError: if the key does not exist.
        """
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            # DEL returns the number of keys removed; 0 means it was absent.
            # A single command avoids the EXISTS-then-DEL race.
            if not rdb.delete(self.name):
                raise redis.DataError(f'Dataset named {self.name} does not exist')
class RedisClient:
    """Facade over the shared connection pool ``CXN`` for managing dataset lists."""

    def get(self, key):
        """Return a ``RedisListObject`` view of *key*.

        Raises:
            redis.DataError: if *key* does not exist.
        """
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            if not rdb.exists(key):
                raise redis.DataError(f'Dataset named {key} does not exist')
        return RedisListObject(key)

    def set_data_list(self, key, values):
        """Replace whatever is stored under *key* with the items of *values*, in order.

        Items are appended one by one, so the replacement is not atomic: a
        concurrent reader may observe a partially written list.
        """
        try:
            self.get(key).delete()
        except redis.DataError:
            pass  # nothing stored under *key* yet; other errors propagate
        obj = RedisListObject(key)
        for item in tqdm(values, desc=f"storing {key}", dynamic_ncols=True):
            obj.append(item)

    def keys(self, pattern='*'):
        """Return the keys matching *pattern* (every key by default).

        Uses the KEYS command, which scans the whole keyspace in one call.
        """
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            return rdb.keys(pattern)

    def stats(self):
        """Return the server's MEMORY STATS report.

        If ``memory_stats()`` fails, the raw ``MEMORY STATS`` command is sent
        instead. NOTE(review): presumably this covers redis-py versions
        without ``memory_stats``; confirm.
        """
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            try:
                return rdb.memory_stats()
            except Exception:  # not bare: let KeyboardInterrupt/SystemExit through
                return rdb.execute_command('MEMORY STATS')

    def check_lens(self, nums):
        """Check that each dataset in *nums* exists with the expected length.

        Args:
            nums: mapping of key -> expected number of elements. An expected
                value of 0 skips the length comparison and only requires the
                key to exist. NOTE(review): this is inferred from the original
                ``v != 0`` guard; confirm the intent.

        Returns:
            True if every key passes. False if any key is missing, has a
            different length, or a Redis error occurs.
        """
        try:
            for key, expected in nums.items():
                obj = self.get(key)
                if expected != 0 and len(obj) != expected:
                    return False
        except redis.RedisError:
            return False
        return True

    def flushdb(self):
        """Delete every key in the database selected by the ``CXN`` pool."""
        with redis.StrictRedis(connection_pool=CXN) as rdb:
            rdb.flushdb()
if __name__ == "__main__":
    # Smoke test: round-trip ten tuples of random tensors through Redis,
    # printing server memory stats before and after.
    c = RedisClient()
    print(c.stats())
    data_list = [tuple(torch.rand(10, 10) for _ in range(10)) for _ in range(10)]
    c.set_data_list("test", data_list)
    stored = c.get("test")
    print(stored[0], stored[1])
    # Remove only the demo key. FLUSHDB would also wipe every real dataset
    # stored in the same database.
    stored.delete()
    print(c.stats())