from abc import ABCMeta, abstractmethod
import logger
from model import Model
import pymongo
from pymongo import MongoClient
from pyspark.mllib.common import _py2java
from pyspark.mllib.recommendation import MatrixFactorizationModel
import time
import urlparse


class ModelFactory:
    """Model store factory for concrete backends.

    Returns a concrete instance of a model store backend using a connection
    URL.
    """
@staticmethod
def fromURL(sc, url):
"""Parse a model store connection URL and returns the appropriate
model store backend
:param sc: A Spark context
:type sc: SparkContext
:param url: The model store connection URL
:type url: str
:return: ModelReader
"""
        _parsed_url = urlparse.urlparse(url)
        # only the mongodb:// scheme is currently handled
        if _parsed_url[0].lower() == 'mongodb':
            return MongoModelReader(sc=sc, uri=url)
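
# Illustrative usage sketch (hedged example; `sc` is assumed to be an existing
# SparkContext and the MongoDB URI is a placeholder, not defined here):
#
#     reader = ModelFactory.fromURL(sc, 'mongodb://localhost:27017')
#     model = reader.readLatest()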


class ModelReader:
    """
    Abstract class for reading a model from a model store.

    Implement backend-specific model stores as subclasses.
    """
__metaclass__ = ABCMeta
def __init__(self, sc, uri):
"""
:param sc: A Spark context
:param uri: The connection URI
"""
self._sc = sc
self._url = uri
self._logger = logger.get_logger()
@abstractmethod
def read(self, version):
"""
Read a specific model version from the model store
:param version: unique model identifier
:return: A `Model` instance
"""
pass
@abstractmethod
def readLatest(self):
"""
Read the latest model from the model store
:return: A `Model` instance
"""
pass
def instantiate(self, rank, version, userFeatures, productFeatures):
"""
Instantiate an ALS `Model` from the model store data
:param rank: ALS rank
:param version: `Model` version
:param userFeatures: A list of user features
:param productFeatures: A list of product features
:return: A `Model` instance
"""
        jvm = self._sc._gateway.jvm
        # build the JVM-side ALS model from the latent factor RDDs via the
        # Py4J gateway
        als_model = jvm.io.radanalytics.als.ALSSerializer.instantiateModel(
            rank, userFeatures,
            productFeatures)
        # wrap it so it can be used from Python as a MatrixFactorizationModel
        wrapper = jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(  # noqa: E501
            als_model)
        model = Model(sc=self._sc, als_model=MatrixFactorizationModel(wrapper),
                      version=version, data_version=1)
        return model
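
# Sketch of how a concrete reader is expected to feed `instantiate` (hedged
# example; `reader` is a concrete ModelReader and `factor_tuples` is a
# hypothetical list of (id, [float, ...]) pairs):
#
#     user_jrdd = _py2java(sc, sc.parallelize(factor_tuples))
#     product_jrdd = _py2java(sc, sc.parallelize(factor_tuples))
#     model = reader.instantiate(rank=10, version=1,
#                                userFeatures=user_jrdd,
#                                productFeatures=product_jrdd)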


class MongoModelReader(ModelReader):
    """This class allows reading serialized ALS models from a MongoDB backend.
    """
def __init__(self, sc, uri):
super(MongoModelReader, self).__init__(sc=sc, uri=uri)
client = MongoClient(self._url)
self._logger.debug(
"Initializing a MongoDB model reader (at {})".format(uri))
self._db = client.models
def read(self, version):
        # record the start time to benchmark model loading
start_time = time.time()
# read a serialized model with a specific version id
data = list(self._db.models.find({'id': version}))
# read the model's rank metadata
rank = data[0]['rank']
        # convert the latent factors read from MongoDB into Spark RDDs and
        # hand them over to the JVM
userFactors = _py2java(self._sc, self._sc.parallelize(
self.extractFeatures(
list(self._db.userFactors.find({'model_id': version})))))
productFactors = _py2java(self._sc, self._sc.parallelize(
self.extractFeatures(
list(self._db.productFactors.find({'model_id': version})))))
        # instantiate Spark's `MatrixFactorizationModel` from the
        # latent factor RDDs
_instantiated = self.instantiate(rank=rank,
version=version,
userFeatures=userFactors,
productFeatures=productFactors)
# time elapsed for model loading
elapsed_time = time.time() - start_time
self._logger.info(
"Model version {0} took {1} seconds to instantiate".format(version,
elapsed_time)) # noqa: E501
return _instantiated
    def readLatest(self):
        """Read the latest model from the MongoDB model store."""
        # fetch model metadata sorted by creation time, newest first
        data = list(self._db.models.find().sort('created', pymongo.DESCENDING))
version = data[0]['id']
self._logger.debug("Latest model found has id={}.".format(version))
rank = data[0]['rank']
userFactors = _py2java(self._sc, self._sc.parallelize(
self.extractFeatures(
list(self._db.userFactors.find({'model_id': version})))))
productFactors = _py2java(self._sc, self._sc.parallelize(
self.extractFeatures(
list(self._db.productFactors.find({'model_id': version})))))
return self.instantiate(rank=rank,
version=version,
userFeatures=userFactors,
productFeatures=productFactors)
    def latestId(self):
        """
        Read the most recent model's id from the MongoDB model store.

        :return: A model version
        :rtype: int
        """
data = list(self._db.models.find().sort('created', pymongo.DESCENDING))
version = data[0]['id']
self._logger.debug("Latest model found has id={}.".format(version))
return version
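
    # Illustrative use of latestId (assumption: callers poll it to decide
    # whether a newer model should be re-loaded; `reader` is a
    # MongoModelReader instance and `current_version` is the caller's state):
    #
    #     if reader.latestId() != current_version:
    #         model = reader.readLatest()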
    @staticmethod
    def extractFeatures(data):
        """
        Format the latent features read from MongoDB into a list of
        (`id`, `features`) tuples.

        :param data: Latent feature documents read from MongoDB
        :return: List[Tuple[int, List[float]]]
        """
        return [(item['id'], item['features'],) for item in data]
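
# Assumed MongoDB document shapes, inferred from the queries above (all
# values are illustrative only):
#
#     models:         {'id': 1, 'rank': 10, 'created': ...}
#     userFactors:    {'model_id': 1, 'id': 42, 'features': [0.12, -0.87]}
#     productFactors: {'model_id': 1, 'id': 7, 'features': [0.44, 0.05]}
#
# extractFeatures turns a list of such factor documents into
#     [(42, [0.12, -0.87]), (7, [0.44, 0.05])]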


class ParquetModelReader(ModelReader):
    """This class allows reading serialized ALS models from a Parquet file.
    """
def __init__(self, sc, uri):
super(ParquetModelReader, self).__init__(sc=sc, uri=uri)
    def read(self, version):
        """Load the ALS model stored as Parquet at the reader's URI and
        return it as a `Model` with the given version id.
        """
        als_model = MatrixFactorizationModel.load(self._sc, self._url)
        model = Model(sc=self._sc, als_model=als_model, version=version,
                      data_version=1)
        return model
    def readLatest(self):
        # not implemented for the Parquet backend
        pass
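
# Minimal usage sketch for the Parquet backend (hedged example; the path is a
# placeholder and `sc` is assumed to be an existing SparkContext). The reader
# is constructed directly because ModelFactory.fromURL currently only
# recognizes mongodb:// URLs:
#
#     reader = ParquetModelReader(sc, '/path/to/als-model')
#     model = reader.read(version=1)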