-
Notifications
You must be signed in to change notification settings - Fork 325
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature] add client demo for sending request to easyrec processor (#363
- Loading branch information
1 parent
adc5f25
commit 135671a
Showing
17 changed files
with
361 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# EasyRecProcessor Client | ||
|
||
Demo | ||
|
||
```bash | ||
python -m easy_rec.python.client.client_demo \ | ||
--endpoint 1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com \ | ||
--service_name ali_rec_rnk_sample_rt_v3 \ | ||
--token MmQ3Yxxxxxxxxxxx \ | ||
--table_schema data/test/client/user_table_schema \ | ||
--table_data data/test/client/user_table_data \ | ||
--item_lst data/test/client/item_lst | ||
|
||
# output: | ||
# results { | ||
# key: "item_0" | ||
# value { | ||
# scores: 0.0 | ||
# scores: 0.0 | ||
# } | ||
# } | ||
# results { | ||
# key: "item_1" | ||
# value { | ||
# scores: 0.0 | ||
# scores: 0.0 | ||
# } | ||
# } | ||
# results { | ||
# key: "item_2" | ||
# value { | ||
# scores: 0.0 | ||
# scores: 0.0 | ||
# } | ||
# } | ||
# outputs: "probs_is_click" | ||
# outputs: "probs_is_go" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# -*- encoding:utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import argparse | ||
import logging | ||
import sys | ||
import traceback | ||
|
||
from easyrec_request import EasyrecRequest | ||
|
||
from easy_rec.python.protos.predict_pb2 import PBFeature | ||
from easy_rec.python.protos.predict_pb2 import PBRequest | ||
|
||
logging.basicConfig( | ||
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s') | ||
|
||
try: | ||
from eas_prediction import PredictClient # TFRequest | ||
except Exception: | ||
logging.error('eas_prediction is not installed: pip install eas-prediction') | ||
sys.exit(1) | ||
|
||
|
||
def build_request(table_cols, table_data, item_ids=None): | ||
request_pb = PBRequest() | ||
assert isinstance(table_data, list) | ||
try: | ||
for col_id in range(len(table_cols)): | ||
cname, dtype = table_cols[col_id] | ||
value = table_data[col_id] | ||
feat = PBFeature() | ||
if value is None: | ||
continue | ||
if dtype == 'STRING': | ||
feat.string_feature = value | ||
elif dtype in ('FLOAT', 'DOUBLE'): | ||
feat.float_feature = value | ||
elif dtype == 'BIGINT': | ||
feat.long_feature = value | ||
elif dtype == 'INT': | ||
feat.int_feature = value | ||
|
||
request_pb.user_features[cname].CopyFrom(feat) | ||
except Exception: | ||
traceback.print_exc() | ||
sys.exit() | ||
request_pb.item_ids.extend(item_ids) | ||
return request_pb | ||
|
||
|
||
def parse_table_schema(create_table_sql): | ||
create_table_sql = create_table_sql.lower() | ||
spos = create_table_sql.index('(') | ||
epos = create_table_sql[spos + 1:].index(')') | ||
cols = create_table_sql[(spos + 1):epos] | ||
cols = [x.strip().lower() for x in cols.split(',')] | ||
col_info_arr = [] | ||
for col in cols: | ||
col = [k for k in col.split() if k != ''] | ||
assert len(col) == 2 | ||
col[1] = col[1].upper() | ||
col_info_arr.append(col) | ||
return col_info_arr | ||
|
||
|
||
def send_request(req_pb, client, debug_level=0): | ||
req = EasyrecRequest() | ||
req.add_feed(req_pb, debug_level) | ||
tmp = client.predict(req) | ||
return tmp | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
'--endpoint', | ||
type=str, | ||
default=None, | ||
help='eas endpoint, such as 12345.cn-beijing.pai-eas.aliyuncs.com') | ||
parser.add_argument( | ||
'--service_name', type=str, default=None, help='eas service name') | ||
parser.add_argument( | ||
'--token', type=str, default=None, help='eas service token') | ||
parser.add_argument( | ||
'--table_schema', | ||
type=str, | ||
default=None, | ||
help='user feature table schema path') | ||
parser.add_argument( | ||
'--table_data', | ||
type=str, | ||
default=None, | ||
help='user feature table data path') | ||
parser.add_argument('--item_lst', type=str, default=None, help='item list') | ||
|
||
args, _ = parser.parse_known_args() | ||
|
||
if args.endpoint is None: | ||
logging.error('--endpoint is not set') | ||
sys.exit(1) | ||
if args.service_name is None: | ||
logging.error('--service_name is not set') | ||
sys.exit(1) | ||
if args.token is None: | ||
logging.error('--token is not set') | ||
sys.exit(1) | ||
if args.table_schema is None: | ||
logging.error('--table_schema is not set') | ||
sys.exit(1) | ||
if args.table_data is None: | ||
logging.error('--table_data is not set') | ||
sys.exit(1) | ||
if args.item_lst is None: | ||
logging.error('--item_lst is not set') | ||
sys.exit(1) | ||
|
||
client = PredictClient(args.endpoint, args.service_name) | ||
client.set_token(args.token) | ||
client.init() | ||
|
||
with open(args.table_schema, 'r') as fin: | ||
create_table_sql = fin.read().strip() | ||
|
||
with open(args.table_data, 'r') as fin: | ||
table_data = fin.read().strip() | ||
|
||
table_cols = parse_table_schema(create_table_sql) | ||
table_data = table_data.split(';') | ||
|
||
with open(args.item_lst, 'r') as fin: | ||
items = fin.read().strip() | ||
items = items.split(',') | ||
|
||
req = build_request(table_cols, table_data, item_ids=items) | ||
resp = send_request(req, client) | ||
logging.info(resp) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
from eas_prediction.request import Request | ||
|
||
from easy_rec.python.protos.predict_pb2 import PBRequest | ||
from easy_rec.python.protos.predict_pb2 import PBResponse | ||
|
||
# from eas_prediction.request import Response | ||
|
||
|
||
class EasyrecRequest(Request): | ||
"""Request for tensorflow services whose input data is in format of protobuf. | ||
This class privide methods to fill generate PBRequest and parse PBResponse. | ||
""" | ||
|
||
def __init__(self, signature_name=None): | ||
self.request_data = PBRequest() | ||
self.signature_name = signature_name | ||
|
||
def __str__(self): | ||
return self.request_data | ||
|
||
def set_signature_name(self, singature_name): | ||
"""Set the signature name of the model. | ||
Args: | ||
singature_name: signature name of the model | ||
""" | ||
self.signature_name = singature_name | ||
|
||
def add_feed(self, data, dbg_lvl=0): | ||
if not isinstance(data, PBRequest): | ||
self.request_data.ParseFromString(data) | ||
else: | ||
self.request_data = data | ||
self.request_data.debug_level = dbg_lvl | ||
|
||
def add_user_fea_flt(self, k, v): | ||
self.request_data.user_features[k].float_feature = float(v) | ||
|
||
def add_user_fea_s(self, k, v): | ||
self.request_data.user_features[k].string_feature = str(v) | ||
|
||
def set_faiss_neigh_num(self, neigh_num): | ||
self.request_data.faiss_neigh_num = neigh_num | ||
|
||
def keep_one_item_ids(self): | ||
item_id = self.request_data.item_ids[0] | ||
self.request_data.ClearField('item_ids') | ||
self.request_data.item_ids.extend([item_id]) | ||
|
||
def to_string(self): | ||
"""Serialize the request to string for transmission. | ||
Returns: | ||
the request data in format of string | ||
""" | ||
return self.request_data.SerializeToString() | ||
|
||
def parse_response(self, response_data): | ||
"""Parse the given response data in string format to the related TFResponse object. | ||
Args: | ||
response_data: the service response data in string format | ||
Returns: | ||
the TFResponse object related the request | ||
""" | ||
self.response = PBResponse() | ||
self.response.ParseFromString(response_data) | ||
return self.response |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.