Skip to content

Commit

Permalink
[feature] add client demo for sending request to easyrec processor (#363
Browse files Browse the repository at this point in the history
)

* add client demo for sending request to easyrec processor
* fix code style
  • Loading branch information
chengmengli06 authored May 6, 2023
1 parent adc5f25 commit 135671a
Show file tree
Hide file tree
Showing 17 changed files with 361 additions and 42 deletions.
1 change: 1 addition & 0 deletions .git_bin_path
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{"leaf_name": "data/test", "leaf_file": ["data/test/batch_criteo_sample.tfrecord", "data/test/criteo_sample.tfrecord", "data/test/dwd_avazu_ctr_deepmodel_10w.csv", "data/test/embed_data.csv", "data/test/lookup_data.csv", "data/test/tag_kv_data.csv", "data/test/test.csv", "data/test/test_sample_weight.txt", "data/test/test_with_quote.csv"]}
{"leaf_name": "data/test/client", "leaf_file": ["data/test/client/item_lst", "data/test/client/user_table_data", "data/test/client/user_table_schema"]}
{"leaf_name": "data/test/criteo_data", "leaf_file": ["data/test/criteo_data/category.bin", "data/test/criteo_data/dense.bin", "data/test/criteo_data/label.bin", "data/test/criteo_data/readme"]}
{"leaf_name": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "leaf_file": ["data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/ESTIMATOR_TRAIN_DONE", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/atexit_sync_1661483067", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/checkpoint", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/eval_result.txt", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.index", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/model.ckpt-1000.meta", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/pipeline.config", "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls/version"]}
{"leaf_name": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "leaf_file": ["data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/checkpoint", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/eval_result.txt", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.data-00000-of-00001", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.index", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/model.ckpt-1000.meta", "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt/pipeline.config"]}
Expand Down
1 change: 1 addition & 0 deletions .git_bin_url
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{"leaf_path": "data/test", "sig": "656d73b4e78d0d71e98120050bc51387", "remote_path": "data/git_oss_sample_data/data_test_656d73b4e78d0d71e98120050bc51387"}
{"leaf_path": "data/test/client", "sig": "d2e000187cebd884ee10e3cf804717fc", "remote_path": "data/git_oss_sample_data/data_test_client_d2e000187cebd884ee10e3cf804717fc"}
{"leaf_path": "data/test/criteo_data", "sig": "f224ba0b1a4f66eeda096c88703d3afc", "remote_path": "data/git_oss_sample_data/data_test_criteo_data_f224ba0b1a4f66eeda096c88703d3afc"}
{"leaf_path": "data/test/distribute_eval_test/deepfm_distribute_eval_dwd_avazu_out_multi_cls", "sig": "2bc0c12a09e1f4c39f839972cf09674b", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_deepfm_distribute_eval_dwd_avazu_out_multi_cls_2bc0c12a09e1f4c39f839972cf09674b"}
{"leaf_path": "data/test/distribute_eval_test/dropoutnet_distribute_eval_taobao_ckpt", "sig": "9fde5d2987654f268a231a1c69db5799", "remote_path": "data/git_oss_sample_data/data_test_distribute_eval_test_dropoutnet_distribute_eval_taobao_ckpt_9fde5d2987654f268a231a1c69db5799"}
Expand Down
50 changes: 26 additions & 24 deletions easy_rec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import os
import platform
import sys

import tensorflow as tf

from easy_rec.version import __version__

curr_dir, _ = os.path.split(__file__)
Expand All @@ -16,33 +15,36 @@
logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

if platform.system() == 'Linux':
ops_dir = os.path.join(curr_dir, 'python/ops')
if 'PAI' in tf.__version__:
ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
ops_dir = os.path.join(ops_dir, '1.15')
# Avoid import tensorflow which conflicts with the version used in EasyRecProcessor
if 'PROCESSOR_TEST' not in os.environ:
if platform.system() == 'Linux':
ops_dir = os.path.join(curr_dir, 'python/ops')
import tensorflow as tf
if 'PAI' in tf.__version__:
ops_dir = os.path.join(ops_dir, '1.12_pai')
elif tf.__version__.startswith('1.12'):
ops_dir = os.path.join(ops_dir, '1.12')
elif tf.__version__.startswith('1.15'):
ops_dir = os.path.join(ops_dir, '1.15')
else:
ops_dir = None
else:
ops_dir = None
else:
ops_dir = None

from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export # isort:skip # noqa: E402
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402
from easy_rec.python.inference.predictor import Predictor # isort:skip # noqa: E402
from easy_rec.python.main import evaluate # isort:skip # noqa: E402
from easy_rec.python.main import distribute_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export # isort:skip # noqa: E402
from easy_rec.python.main import train_and_evaluate # isort:skip # noqa: E402
from easy_rec.python.main import export_checkpoint # isort:skip # noqa: E402

try:
import tensorflow_io.oss
except Exception:
pass
try:
import tensorflow_io.oss
except Exception:
pass

print('easy_rec version: %s' % __version__)
print('Usage: easy_rec.help()')
print('easy_rec version: %s' % __version__)
print('Usage: easy_rec.help()')

_global_config = {}

Expand Down
38 changes: 38 additions & 0 deletions easy_rec/python/inference/client/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# EasyRecProcessor Client

Demo

```bash
python -m easy_rec.python.inference.client.client_demo \
--endpoint 1301055xxxxxxxxx.cn-hangzhou.pai-eas.aliyuncs.com \
--service_name ali_rec_rnk_sample_rt_v3 \
--token MmQ3Yxxxxxxxxxxx \
--table_schema data/test/client/user_table_schema \
--table_data data/test/client/user_table_data \
--item_lst data/test/client/item_lst

# output:
# results {
# key: "item_0"
# value {
# scores: 0.0
# scores: 0.0
# }
# }
# results {
# key: "item_1"
# value {
# scores: 0.0
# scores: 0.0
# }
# }
# results {
# key: "item_2"
# value {
# scores: 0.0
# scores: 0.0
# }
# }
# outputs: "probs_is_click"
# outputs: "probs_is_go"
```
135 changes: 135 additions & 0 deletions easy_rec/python/inference/client/client_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import logging
import sys
import traceback

from easyrec_request import EasyrecRequest

from easy_rec.python.protos.predict_pb2 import PBFeature
from easy_rec.python.protos.predict_pb2 import PBRequest

logging.basicConfig(
level=logging.INFO, format='[%(asctime)s][%(levelname)s] %(message)s')

try:
from eas_prediction import PredictClient # TFRequest
except Exception:
logging.error('eas_prediction is not installed: pip install eas-prediction')
sys.exit(1)


def build_request(table_cols, table_data, item_ids=None):
  """Build a PBRequest from user feature values and candidate item ids.

  Args:
    table_cols: list of (col_name, col_type) pairs, col_type one of
      STRING / FLOAT / DOUBLE / BIGINT / INT.
    table_data: list of feature values aligned with table_cols; entries
      may be raw strings (e.g. read from a text file) and are cast to
      the declared column type.
    item_ids: optional list of item ids to be scored.

  Returns:
    the populated PBRequest.
  """
  request_pb = PBRequest()
  assert isinstance(table_data, list)
  try:
    for col_id in range(len(table_cols)):
      cname, dtype = table_cols[col_id]
      value = table_data[col_id]
      feat = PBFeature()
      if value is None:
        # missing feature: leave it out of user_features entirely
        continue
      if dtype == 'STRING':
        feat.string_feature = str(value)
      elif dtype in ('FLOAT', 'DOUBLE'):
        # cast explicitly: values read from text files arrive as strings,
        # and protobuf scalar fields reject mismatched types
        feat.float_feature = float(value)
      elif dtype == 'BIGINT':
        feat.long_feature = int(value)
      elif dtype == 'INT':
        feat.int_feature = int(value)

      request_pb.user_features[cname].CopyFrom(feat)
  except Exception:
    traceback.print_exc()
    # exit non-zero on failure (bare sys.exit() exits with code 0)
    sys.exit(1)
  if item_ids:
    # guard: extend(None) would raise TypeError when item_ids is omitted
    request_pb.item_ids.extend(item_ids)
  return request_pb


def parse_table_schema(create_table_sql):
  """Parse column (name, type) pairs out of a CREATE TABLE statement.

  Args:
    create_table_sql: SQL text such as
      'create table user(user_id string, age bigint)'.

  Returns:
    list of [col_name, COL_TYPE] pairs; names are lowercased and types
    uppercased, e.g. [['user_id', 'STRING'], ['age', 'BIGINT']].
  """
  create_table_sql = create_table_sql.lower()
  spos = create_table_sql.index('(')
  # find the closing paren as an ABSOLUTE position; the original code
  # took an index relative to the slice [spos+1:] and then used it as an
  # absolute bound, truncating the column list.
  epos = create_table_sql.index(')', spos + 1)
  cols = create_table_sql[(spos + 1):epos]
  cols = [x.strip().lower() for x in cols.split(',')]
  col_info_arr = []
  for col in cols:
    # each column spec is "<name> <type>"; tolerate extra whitespace
    col = [k for k in col.split() if k != '']
    assert len(col) == 2
    col[1] = col[1].upper()
    col_info_arr.append(col)
  return col_info_arr


def send_request(req_pb, client, debug_level=0):
  """Wrap a PBRequest in an EasyrecRequest and send it via the client.

  Args:
    req_pb: the PBRequest to send.
    client: an initialized eas_prediction PredictClient.
    debug_level: debug level forwarded inside the request.

  Returns:
    the response returned by client.predict.
  """
  wrapped = EasyrecRequest()
  wrapped.add_feed(req_pb, debug_level)
  return client.predict(wrapped)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--endpoint',
      type=str,
      default=None,
      help='eas endpoint, such as 12345.cn-beijing.pai-eas.aliyuncs.com')
  parser.add_argument(
      '--service_name', type=str, default=None, help='eas service name')
  parser.add_argument(
      '--token', type=str, default=None, help='eas service token')
  parser.add_argument(
      '--table_schema',
      type=str,
      default=None,
      help='user feature table schema path')
  parser.add_argument(
      '--table_data',
      type=str,
      default=None,
      help='user feature table data path')
  parser.add_argument('--item_lst', type=str, default=None, help='item list')

  args, _ = parser.parse_known_args()

  # every argument is mandatory; report the first missing one and exit
  for arg_name in ('endpoint', 'service_name', 'token', 'table_schema',
                   'table_data', 'item_lst'):
    if getattr(args, arg_name) is None:
      logging.error('--%s is not set' % arg_name)
      sys.exit(1)

  # connect to the EAS service
  client = PredictClient(args.endpoint, args.service_name)
  client.set_token(args.token)
  client.init()

  # user feature table schema: a CREATE TABLE sql statement
  with open(args.table_schema, 'r') as fin:
    create_table_sql = fin.read().strip()

  # one row of user feature values, separated by ';'
  with open(args.table_data, 'r') as fin:
    table_data = fin.read().strip()

  table_cols = parse_table_schema(create_table_sql)
  table_data = table_data.split(';')

  # candidate item ids, separated by ','
  with open(args.item_lst, 'r') as fin:
    items = fin.read().strip()
  items = items.split(',')

  req = build_request(table_cols, table_data, item_ids=items)
  resp = send_request(req, client)
  logging.info(resp)
72 changes: 72 additions & 0 deletions easy_rec/python/inference/client/easyrec_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from eas_prediction.request import Request

from easy_rec.python.protos.predict_pb2 import PBRequest
from easy_rec.python.protos.predict_pb2 import PBResponse

# from eas_prediction.request import Response


class EasyrecRequest(Request):
  """Request for tensorflow services whose input data is in format of protobuf.

  This class provides methods to populate a PBRequest and to parse the
  PBResponse returned by the EasyRec processor.
  """

  def __init__(self, signature_name=None):
    self.request_data = PBRequest()
    self.signature_name = signature_name

  def __str__(self):
    # __str__ must return a str; the original returned the protobuf
    # message object itself, which raises TypeError on str(request).
    return str(self.request_data)

  def set_signature_name(self, singature_name):
    """Set the signature name of the model.

    Args:
      singature_name: signature name of the model
        (parameter name kept as-is for backward compatibility;
        TODO: deprecate in favor of correctly spelled signature_name)
    """
    self.signature_name = singature_name

  def add_feed(self, data, dbg_lvl=0):
    """Set the request payload.

    Args:
      data: either a PBRequest instance, or serialized PBRequest bytes.
      dbg_lvl: debug level stored on the request.
    """
    if not isinstance(data, PBRequest):
      self.request_data.ParseFromString(data)
    else:
      self.request_data = data
    self.request_data.debug_level = dbg_lvl

  def add_user_fea_flt(self, k, v):
    # add a single float-valued user feature
    self.request_data.user_features[k].float_feature = float(v)

  def add_user_fea_s(self, k, v):
    # add a single string-valued user feature
    self.request_data.user_features[k].string_feature = str(v)

  def set_faiss_neigh_num(self, neigh_num):
    # number of neighbors for faiss vector retrieval
    self.request_data.faiss_neigh_num = neigh_num

  def keep_one_item_ids(self):
    # keep only the first item id (e.g. for a minimal debug request)
    item_id = self.request_data.item_ids[0]
    self.request_data.ClearField('item_ids')
    self.request_data.item_ids.extend([item_id])

  def to_string(self):
    """Serialize the request to string for transmission.

    Returns:
      the request data in format of string
    """
    return self.request_data.SerializeToString()

  def parse_response(self, response_data):
    """Parse the given response data in string format to a PBResponse object.

    Args:
      response_data: the service response data in string format

    Returns:
      the PBResponse object related to the request
    """
    self.response = PBResponse()
    self.response.ParseFromString(response_data)
    return self.response
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def build_array_proto(array_proto, data, dtype):
'--test_dir', type=str, default=None, help='test directory')
args = parser.parse_args()

if not os.path.exists('processor'):
os.mkdir('processor')
if not os.path.exists(PROCESSOR_ENTRY_LIB):
if not os.path.exists('processor/' + PROCESSOR_FILE):
subprocess.check_output(
Expand Down
9 changes: 5 additions & 4 deletions easy_rec/python/layers/multihead_cross_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,11 +708,12 @@ def embedding_postprocessor(input_tensor,
if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
with tf.variable_scope("position_embedding", reuse=reuse_position_embedding):
with tf.variable_scope(
'position_embedding', reuse=reuse_position_embedding):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
Expand Down
Loading

0 comments on commit 135671a

Please sign in to comment.